""" Noisy state builder for block-AR SAD training. Noising levels per token (uniformly sampled): 0 = clean (keep leaf embedding) 1..L-1 = ancestor level l (sample from LUT, use learnable ancestor embedding) L = mask token AncestorTable provides: - Fixed LUT (indices + probs): which ancestor cluster each token maps to - Learnable ancestor embeddings: the actual embedding used as noisy input """ from typing import List, Optional, Tuple import torch import torch.nn as nn from src.diffusion.ancestor_table import AncestorTable class NoisyStateBuilder(nn.Module): """ Builds noisy embeddings for vectorized block-AR training. Args: vocab_size: V mask_token_id: id of [MASK] in the leaf vocabulary temp: kept for API compat, unused (temp is baked into LUT probs) top_k_per_level: kept for API compat, unused """ def __init__( self, vocab_size: int, mask_token_id: int, temp: float = 1.0, top_k_per_level: Optional[List[int]] = None, use_soft_expected: bool = True, ): super().__init__() self.vocab_size = vocab_size self.mask_token_id = mask_token_id # temp / top_k_per_level kept for API compat but not used def sample_levels_uniform( self, B: int, S: int, num_total_states: int, device: torch.device ) -> torch.Tensor: """ Sample per-token levels uniformly from {0, 1, ..., num_total_states-1}. num_total_states = num_ancestor_levels + 1 + 1 = (mask_level) where mask_level = num_ancestor_levels + 1 Returns: levels: [B, S] int64 """ return torch.randint(0, num_total_states, (B, S), device=device) @staticmethod def sample_t( B: int, device: torch.device, eps: float = 1e-3, low_discrepancy: bool = False, rank: int = 0, world_size: int = 1, ) -> torch.Tensor: """ Sample per-sequence noise level t ~ U[eps, 1-eps] (shape [B]). Change 3 (MDLM): low-discrepancy stratified sampling. A single shared phase u is drawn per step; the batch covers the grid {u, u+1/B, ...}. In DDP, the phase is offset by rank/(world_size*B) so that all ranks together cover finer strata (analogous to Kingma's VDM). Change 2 (BD3-LMs / Soft-Masked Diffusion): clamp to [eps, 1-eps] to avoid high-variance ELBO gradients at extreme mask rates (t→0 or t→1). """ if low_discrepancy: # Fresh shared phase every step — must not be cached across steps. u = torch.rand((), device=device) if world_size > 1: # Offset phase per rank so disjoint strata across the global batch. u = (u + rank / (world_size * B)) % 1.0 t = torch.arange(B, device=device, dtype=torch.float32) / B t = (t + u) % 1.0 else: t = torch.rand(B, device=device) # Clamp to [eps, 1-eps]: avoid high-variance ELBO at extreme mask rates. return (1 - 2 * eps) * t + eps def sample_levels_hdlm( self, t: torch.Tensor, S: int, num_ancestor_levels: int, gamma: float = 1.0, ) -> torch.Tensor: """ HDLM 3-state (clean / ancestor / mask) hierarchical schedule, γ=1 form from HierarchicalDiffusion.get_alpha_betapi: α_t = 1 - t # P(clean, keep original) c_t = -(1-t) * log(1-t) # P(ancestor state) m_t = t + (1-t) * log(1-t) # P(mask) The three probabilities sum to 1 (with tiny numerical renormalization). Each sequence has its own t (shape [B]); each token within the sequence is drawn i.i.d. from the 3-way categorical. If the token lands on the ancestor state and num_ancestor_levels > 1, c_t is split evenly among levels 1..L. Args: t: [B] in (0, 1), from sample_t() S: sequence length num_ancestor_levels: L = ancestor_table.num_levels (typically 1) gamma: only γ=1 supported Returns: levels: [B, S] int64, values in {0, 1..L, L+1=mask_level} """ if gamma != 1.0: raise NotImplementedError("sample_levels_hdlm only supports γ=1") assert num_ancestor_levels >= 1 B = t.shape[0] device = t.device eps = 1e-8 one_m_t = (1.0 - t).clamp(min=eps) # [B] log_1m_t = one_m_t.log() # [B] <= 0 alpha_t = one_m_t # P(clean) c_t = -one_m_t * log_1m_t # P(ancestor-any) >= 0 m_t = (t + one_m_t * log_1m_t).clamp(min=0.0) # P(mask) >= 0 # Renormalize for safety (should already sum to ~1) total = alpha_t + c_t + m_t alpha_t = alpha_t / total c_t = c_t / total m_t = m_t / total # Split c_t evenly among ancestor levels 1..L c_per_level = c_t / num_ancestor_levels # [B] # Build per-sample probability over categories [clean, anc_1..anc_L, mask] probs = torch.stack( [alpha_t] + [c_per_level] * num_ancestor_levels + [m_t], dim=-1, ) # [B, 2+L] # Per-token multinomial. Category index i maps 1-to-1 to level value: # 0 -> clean (level 0) # 1..L -> ancestor levels 1..L # L+1 -> mask_level (= num_ancestor_levels + 1 in NoisyStateBuilder) probs_flat = ( probs.unsqueeze(1).expand(B, S, -1).reshape(B * S, -1) ) sampled = torch.multinomial(probs_flat, num_samples=1).squeeze(-1) return sampled.view(B, S).long() def build_noisy_embeddings( self, input_ids: torch.Tensor, levels: torch.Tensor, ancestor_table: AncestorTable, leaf_embeddings: torch.Tensor, mask_embedding: torch.Tensor, # kept for API compat (ignored): hierarchy=None, ) -> Tuple[torch.Tensor, torch.Tensor, List[Optional[torch.Tensor]], torch.Tensor]: """ Build noisy embeddings from LUT-based ancestor table. For each position: level 0: leaf_embeddings[token_id] (clean) level 1..L: learnable ancestor embedding sampled via LUT level L+1: mask_embedding (fully masked) Args: input_ids: [B, S] int64 levels: [B, S] int64 (0 .. num_ancestor_levels+1) ancestor_table: AncestorTable leaf_embeddings: [V, d] mask_embedding: [d] hierarchy: ignored (kept for API compat) Returns: noisy_embs: [B, S, d] ancestor_log_probs: [B, S] log-prob of chosen ancestor (0 at clean/mask) ancestor_probs_per_lvl: list of None (not needed downstream) corrupt_mask: [B, S] bool – True at positions level >= 1 """ B, S = input_ids.shape dtype = leaf_embeddings.dtype device = input_ids.device # mask level = num_ancestor_levels + 1 mask_level = ancestor_table.num_levels + 1 # Start with clean embeddings everywhere noisy_embs = leaf_embeddings[input_ids].clone() # [B, S, d] ancestor_log_probs = torch.zeros(B, S, device=device, dtype=dtype) corrupt_mask = torch.zeros(B, S, dtype=torch.bool, device=device) ancestor_probs_per_lvl: List[Optional[torch.Tensor]] = [] # Apply mask mask_pos = (levels == mask_level) if mask_pos.any(): noisy_embs[mask_pos] = mask_embedding.to(dtype) corrupt_mask[mask_pos] = True for l in range(1, mask_level): # ancestor levels 1..mask_level-1 pos_l = (levels == l) if not pos_l.any(): ancestor_probs_per_lvl.append(None) continue flat_ids = input_ids[pos_l] # [N] N = flat_ids.shape[0] # Sample ancestor index via LUT multinomial lut_idx = ancestor_table.lut_indices(l)[flat_ids] # [N, top_k] lut_prob = ancestor_table.lut_probs(l)[flat_ids] # [N, top_k] sampled_local = torch.multinomial(lut_prob, num_samples=1).squeeze(1) # [N] sampled_global = lut_idx[ torch.arange(N, device=device), sampled_local ] # [N] # Learnable ancestor embedding anc_emb = ancestor_table.ancestor_embeddings(l) # [K, d] noisy_embs[pos_l] = anc_emb[sampled_global].to(dtype) # Log-prob of chosen ancestor (from LUT probs) chosen_lp = lut_prob[ torch.arange(N, device=device), sampled_local ].clamp(min=1e-8).log() ancestor_log_probs[pos_l] = chosen_lp.to(dtype) corrupt_mask[pos_l] = True ancestor_probs_per_lvl.append(None) # not needed downstream return noisy_embs, ancestor_log_probs, ancestor_probs_per_lvl, corrupt_mask