"""
Adaptive Skip / Backoff Sampler for SAD.

Samplers defined in this module:
  1. `SADSampler`: main method - token-wise adaptive exit over the whole
     sequence.
  2. `SADBlockSampler`: the same adaptive exit applied block by block
     (Block-AR inference).
  3. `AdjacentOnlySampler`: baseline - no skip.

Confidence metrics supported:
  - "max_prob":    max_k p_leaf[k]
  - "neg_entropy": (H_max - H(p_leaf)) / H_max  (normalized negative entropy)

State per position (adaptive samplers):
  - resolved:    bool [B, S] - True if token is finalized
  - current_ids: [B, S] - current token ids (may be coarse cluster ids)
  - exit_levels: [B, S] - which level each token exited at (for logging)

APPROXIMATION NOTE:
  The sampler currently re-runs the full model on all positions each step,
  even for resolved tokens. Resolved token embeddings are still part of the
  context (they serve as conditioning). This is correct but not optimal for
  speed. A KV-cache variant that skips resolved positions is left as TODO.
"""

import math
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm.auto as tqdm


def compute_confidence(
    p_leaf: torch.Tensor,
    metric: str = "neg_entropy",
) -> torch.Tensor:
    """
    Compute a per-position confidence score from a probability distribution.

    Args:
        p_leaf: [..., V] probabilities along the last dimension
                (the samplers pass [B, S, V]).
        metric: "max_prob" or "neg_entropy".

    Returns:
        Tensor with the last dimension reduced (e.g. [B, S]). For
        "neg_entropy" the result is clamped to [0, 1] and is at least
        float32; for "max_prob" it keeps the dtype of `p_leaf`.

    Raises:
        ValueError: if `metric` is not one of the supported names.
    """
    if metric == "max_prob":
        return p_leaf.max(dim=-1).values

    if metric == "neg_entropy":
        # Work in >= float32. In float16 the eps below underflows to 0, so a
        # zero probability (e.g. the masked-out mask token) would produce
        # 0 * log(0) = NaN and poison every threshold comparison.
        work_dtype = torch.promote_types(p_leaf.dtype, torch.float32)
        probs = p_leaf.to(work_dtype)
        eps = 1e-8
        entropy = -(probs * (probs + eps).log()).sum(dim=-1)
        # log(V) is the entropy of the uniform distribution; the floor keeps
        # the division finite when V == 1.
        max_entropy = max(math.log(max(probs.shape[-1], 1)), 1e-8)
        # 0 = maximum entropy (least confident), 1 = zero entropy.
        return (1.0 - entropy / max_entropy).clamp(0.0, 1.0)

    raise ValueError(f"Unknown confidence metric: {metric}")


def _resolve_thresholds(
    thresholds: Optional[List[float]],
    num_levels: int,
) -> List[float]:
    """Return per-level thresholds, filling in defaults and validating length."""
    if thresholds is None:
        # Default: leaf threshold 0.8, coarser levels 0.5.
        return [0.8] + [0.5] * (num_levels - 1)
    if len(thresholds) < num_levels:
        raise ValueError(
            f"thresholds has {len(thresholds)} entries but the hierarchy "
            f"has {num_levels} levels"
        )
    return thresholds


def _level_offsets(hierarchy) -> List[int]:
    """
    Build the start offset of every level in the extended id space.

    offsets[0] = 0 (leaf ids); offsets[l] = offsets[l-1] + level_sizes[l-1].
    NOTE(review): for cluster ids not to collide with leaf ids this presumes
    hierarchy.level_sizes[0] equals the leaf vocabulary size - confirm
    against the hierarchy implementation.
    """
    offsets = [0]
    for l in range(1, hierarchy.num_levels):
        offsets.append(offsets[-1] + hierarchy.level_sizes[l - 1])
    return offsets


def _normalized_prototypes(hierarchy) -> List[torch.Tensor]:
    """Return L2-normalized float prototypes for levels 1..num_levels-1."""
    return [
        F.normalize(hierarchy.prototypes[l - 1].float(), dim=-1)
        for l in range(1, hierarchy.num_levels)
    ]


def _ids_to_embeddings(
    ids: torch.Tensor,
    leaf_emb: torch.Tensor,
    mask_id: int,
    vocab_size: int,
    hierarchy,
    offsets: List[int],
) -> torch.Tensor:
    """
    Convert [B, S] ids (leaf / offset cluster id / mask) to [B, S, d] embeddings.

        mask_id             -> leaf_emb[mask_id]
        leaf id (0..V-1)    -> leaf_emb[id]            (excluding mask_id)
        cluster id at lvl l -> prototypes[l-1][id - offsets[l]]

    Ids that match none of the ranges keep an all-zero embedding.
    """
    out = torch.zeros(
        *ids.shape, leaf_emb.shape[-1], device=ids.device, dtype=leaf_emb.dtype
    )

    # Mask positions are handled first so they are excluded from the leaf range.
    mask_pos = ids == mask_id
    if mask_pos.any():
        out[mask_pos] = leaf_emb[mask_id]

    leaf_pos = (ids < vocab_size) & ~mask_pos
    if leaf_pos.any():
        out[leaf_pos] = leaf_emb[ids[leaf_pos]]

    for l in range(1, hierarchy.num_levels):
        lo = offsets[l]
        hi = lo + hierarchy.level_sizes[l]
        cluster_pos = (ids >= lo) & (ids < hi)
        if cluster_pos.any():
            proto = hierarchy.prototypes[l - 1]
            # Cast so prototypes stored in another dtype can still be written.
            out[cluster_pos] = proto[ids[cluster_pos] - lo].to(out.dtype)
    return out


def _adaptive_update(
    p_leaf: torch.Tensor,
    h: torch.Tensor,
    proto_norms: List[torch.Tensor],
    thresholds: List[float],
    metric: str,
    offsets: List[int],
    ids: torch.Tensor,
    level: torch.Tensor,
    resolved: torch.Tensor,
    exit_levels: torch.Tensor,
) -> torch.Tensor:
    """
    Apply one finest-first adaptive skip update (no backtracking).

    For every unresolved position the leaf is tried first; if its confidence
    reaches thresholds[0] the token is finalized. Otherwise levels
    1..num_levels-1 are tried from finest to coarsest and the position moves
    to the first level that is strictly finer than its current level and
    whose confidence reaches that level's threshold.

    Args:
        p_leaf:      [B, S, V] leaf probabilities.
        h:           [B, S, d] hidden states.
        proto_norms: normalized prototypes, proto_norms[l-1] is [K_l, d].
        thresholds:  per-level thresholds, index 0 is the leaf.
        metric:      confidence metric name for `compute_confidence`.
        offsets:     per-level start offsets in the extended id space.
        ids:         [B, S] current ids (not modified).
        level, resolved, exit_levels: [B, S] state tensors, MODIFIED IN PLACE.

    Returns:
        [B, S] updated ids (a new tensor).
    """
    h_norm = F.normalize(h.float(), dim=-1)
    conf_leaf = compute_confidence(p_leaf, metric)

    new_ids = ids.clone()
    # Snapshot taken before this step's updates: tokens finalized below are
    # excluded from the coarser levels through `updated`.
    active = ~resolved
    updated = torch.zeros_like(resolved)

    # Leaf (finest) first.
    finalize = active & (level > 0) & (conf_leaf >= thresholds[0])
    if finalize.any():
        new_ids[finalize] = p_leaf.argmax(dim=-1)[finalize]
        level[finalize] = 0
        resolved[finalize] = True
        exit_levels[finalize] = 0
        updated |= finalize

    # Intermediate levels, finest (l=1) to coarsest (l=num_levels-1).
    # NOTE(review): cosine similarities lie in [-1, 1], so their softmax is
    # close to uniform for large K_l and these thresholds may rarely be met.
    for l, proto_norm in enumerate(proto_norms, start=1):
        sim = h_norm @ proto_norm.T  # [B, S, K_l]
        conf_l = compute_confidence(sim.softmax(dim=-1), metric)
        assign = active & ~updated & (level > l) & (conf_l >= thresholds[l])
        if assign.any():
            # softmax is monotonic, so the best cluster is argmax of sim.
            new_ids[assign] = sim.argmax(dim=-1)[assign] + offsets[l]
            level[assign] = l
            updated |= assign

    return new_ids


class SADSampler(nn.Module):
    """
    Token-wise adaptive skip / backoff sampler.

    At each denoising step:
      1. Run model -> leaf logits + hidden state h
      2. Compute per-level confidence via h x prototype inner products
      3. For each position, find the finest level with confidence >= threshold
      4. If leaf conf >= tau_0: finalize the token
      5. Else: update state to the appropriate coarse representation

    Args:
        model: SADModel
        hierarchy: SoftAncestorHierarchy or HardAncestorHierarchy
        tokenizer: HuggingFace tokenizer
        confidence_metric: "neg_entropy" or "max_prob"
        thresholds: list of per-level thresholds tau_0..tau_{L-1}
                    tau_0 is for leaf, tau_1 for level-1, ...
                    Higher threshold = harder to finalize at that level.
        freeze_resolved: stored on the instance; not read by `generate`.

    Raises:
        ValueError: if `thresholds` has fewer entries than hierarchy levels.
    """

    def __init__(
        self,
        model: nn.Module,
        hierarchy,
        tokenizer,
        confidence_metric: str = "neg_entropy",
        thresholds: Optional[List[float]] = None,
        freeze_resolved: bool = True,
    ):
        super().__init__()
        self.model = model
        self.hierarchy = hierarchy
        self.tokenizer = tokenizer
        self.confidence_metric = confidence_metric
        self.freeze_resolved = freeze_resolved
        self.thresholds = _resolve_thresholds(thresholds, hierarchy.num_levels)

    @torch.no_grad()
    def generate(
        self,
        num_samples: int,
        num_steps: int,
        max_length: int = 512,
        device=None,
        show_progress: bool = True,
        random_coarse_init: bool = False,
    ) -> Tuple[torch.Tensor, Dict]:
        """
        Generate `num_samples` sequences of length `max_length`.

        Args:
            num_samples: B
            num_steps: maximum number of denoising iterations
            max_length: S
            device: target device (defaults to the model's parameter device)
            show_progress: show tqdm bar
            random_coarse_init: accepted for API compatibility; currently
                not used, the state always starts from all-mask.

        Returns:
            token_ids: [B, S] final generated token ids (on CPU)
            stats: dict with 'exit_levels' [B, S] and 'resolved_over_steps'
                (fraction of resolved tokens after each step).
        """
        if device is None:
            device = next(self.model.parameters()).device

        B, S = num_samples, max_length
        mask_id = self.tokenizer.mask_token_id
        num_levels = self.hierarchy.num_levels
        vocab_size = self.model.vocab_size

        offsets = _level_offsets(self.hierarchy)
        leaf_emb = self.model.get_leaf_embeddings()  # [V, d]
        # Prototypes do not change during sampling: normalize them once.
        proto_norms = _normalized_prototypes(self.hierarchy)

        def embed(ids: torch.Tensor) -> torch.Tensor:
            return _ids_to_embeddings(
                ids, leaf_emb, mask_id, vocab_size, self.hierarchy, offsets
            )

        # ---- Initialize state: all MASK (level = num_levels) ----
        # current_level: num_levels = MASK, 1..num_levels-1 = cluster level,
        # 0 = finalized leaf.
        current_ids = torch.full((B, S), mask_id, dtype=torch.long, device=device)
        current_level = torch.full((B, S), num_levels, dtype=torch.long, device=device)
        resolved = torch.zeros(B, S, dtype=torch.bool, device=device)
        exit_levels = torch.full((B, S), num_levels, dtype=torch.long, device=device)
        resolved_over_steps = []

        for step_i in tqdm.trange(num_steps, desc="sad", disable=not show_progress):
            leaf_logits, h = self.model(input_embeddings=embed(current_ids))
            leaf_logits[..., mask_id] = float("-inf")  # never predict MASK
            p_leaf = leaf_logits.softmax(dim=-1)  # [B, S, V]

            current_ids = _adaptive_update(
                p_leaf,
                h,
                proto_norms,
                self.thresholds,
                self.confidence_metric,
                offsets,
                ids=current_ids,
                level=current_level,
                resolved=resolved,
                exit_levels=exit_levels,
            )

            resolved_over_steps.append(resolved.float().mean().item())
            if resolved.all():
                if show_progress:
                    print(f"All tokens resolved at step {step_i + 1}")
                break

        # Final: force any unresolved position to its argmax leaf. Their
        # exit_levels entry keeps the initial value num_levels.
        unresolved = ~resolved
        if unresolved.any():
            leaf_logits, _ = self.model(input_embeddings=embed(current_ids))
            leaf_logits[..., mask_id] = float("-inf")
            current_ids[unresolved] = leaf_logits.argmax(dim=-1)[unresolved]

        stats = {
            "exit_levels": exit_levels.cpu(),
            "resolved_over_steps": resolved_over_steps,
        }
        return current_ids.cpu(), stats


class SADBlockSampler(nn.Module):
    """
    Block-by-block adaptive skip sampler for Block-AR inference.

    Generates one block at a time in autoregressive order. Within each block
    the same adaptive-skip logic as SADSampler applies (finest-first
    confidence check via h x prototype inner products, no backtracking).
    Every step runs model.forward_causal() on the full sequence, so already
    committed blocks act as context.

    Args:
        model: SADModel (must expose `block_size` and `forward_causal`)
        hierarchy: SoftAncestorHierarchy or HardAncestorHierarchy
        tokenizer: HuggingFace tokenizer (needs mask_token_id)
        confidence_metric: "neg_entropy" or "max_prob"
        thresholds: per-level thresholds [tau_leaf, tau_l1, tau_l2, ...]
        num_steps_per_block: denoising steps per block

    Raises:
        ValueError: if `thresholds` has fewer entries than hierarchy levels.
    """

    def __init__(
        self,
        model: nn.Module,
        hierarchy,
        tokenizer,
        confidence_metric: str = "neg_entropy",
        thresholds: Optional[List[float]] = None,
        num_steps_per_block: int = 20,
    ):
        super().__init__()
        self.model = model
        self.hierarchy = hierarchy
        self.tokenizer = tokenizer
        self.confidence_metric = confidence_metric
        self.num_steps_per_block = num_steps_per_block
        self.thresholds = _resolve_thresholds(thresholds, hierarchy.num_levels)

    @torch.no_grad()
    def generate(
        self,
        num_samples: int,
        max_length: int = 512,
        device=None,
        show_progress: bool = True,
    ) -> Tuple[torch.Tensor, dict]:
        """
        Generate `num_samples` sequences of length `max_length`.

        Args:
            num_samples: B
            max_length: S, must be a multiple of model.block_size
            device: target device (defaults to the model's parameter device)
            show_progress: show a tqdm bar over blocks

        Returns:
            token_ids: [B, S] (on CPU)
            stats: dict with 'exit_levels' [B, S]

        Raises:
            ValueError: if `max_length` is not divisible by the block size.
        """
        if device is None:
            device = next(self.model.parameters()).device

        block_size = self.model.block_size
        if max_length % block_size != 0:
            raise ValueError(
                f"max_length ({max_length}) must be divisible by "
                f"block_size ({block_size})"
            )

        B, S = num_samples, max_length
        num_blocks = S // block_size
        mask_id = self.tokenizer.mask_token_id
        num_levels = self.hierarchy.num_levels
        vocab_size = self.model.vocab_size

        offsets = _level_offsets(self.hierarchy)
        leaf_emb = self.model.get_leaf_embeddings()  # [V, d]
        proto_norms = _normalized_prototypes(self.hierarchy)

        def embed(ids: torch.Tensor) -> torch.Tensor:
            return _ids_to_embeddings(
                ids, leaf_emb, mask_id, vocab_size, self.hierarchy, offsets
            )

        # State for the full sequence.
        current_ids = torch.full((B, S), mask_id, dtype=torch.long, device=device)
        exit_levels = torch.full((B, S), num_levels, dtype=torch.long, device=device)

        block_iter = tqdm.tqdm(
            range(num_blocks), desc="sad_block", disable=not show_progress
        )
        for blk in block_iter:
            blk_start = blk * block_size
            blk_slice = slice(blk_start, blk_start + block_size)

            # Block-level state.
            blk_current = current_ids[:, blk_slice].clone()  # [B, bs]
            blk_level = torch.full(
                (B, block_size), num_levels, dtype=torch.long, device=device
            )
            blk_resolved = torch.zeros(B, block_size, dtype=torch.bool, device=device)
            # View into exit_levels: in-place writes land in the full tensor.
            blk_exit = exit_levels[:, blk_slice]

            for _ in range(self.num_steps_per_block):
                # Full-sequence input: earlier blocks + current block state.
                current_ids[:, blk_slice] = blk_current
                leaf_logits, h = self.model.forward_causal(embed(current_ids))

                # Focus on the current block only.
                blk_logits = leaf_logits[:, blk_slice, :]  # [B, bs, V]
                blk_logits[..., mask_id] = float("-inf")
                p_leaf = blk_logits.softmax(dim=-1)

                blk_current = _adaptive_update(
                    p_leaf,
                    h[:, blk_slice, :],
                    proto_norms,
                    self.thresholds,
                    self.confidence_metric,
                    offsets,
                    ids=blk_current,
                    level=blk_level,
                    resolved=blk_resolved,
                    exit_levels=blk_exit,
                )

                if blk_resolved.all():
                    break

            # Force-resolve any remaining tokens in the block.
            unresolved = ~blk_resolved
            if unresolved.any():
                current_ids[:, blk_slice] = blk_current
                blk_logits, _ = self.model.forward_causal(embed(current_ids))
                blk_logits = blk_logits[:, blk_slice, :]
                blk_logits[..., mask_id] = float("-inf")
                blk_current[unresolved] = blk_logits.argmax(dim=-1)[unresolved]
                blk_exit[unresolved] = 0

            # Commit the block to current_ids (leaf ids only at this point).
            current_ids[:, blk_slice] = blk_current

        stats = {"exit_levels": exit_levels.cpu()}
        return current_ids.cpu(), stats


class AdjacentOnlySampler(nn.Module):
    """
    Baseline sampler without adaptive exit.

    Each iteration takes the model's argmax prediction and unmasks a random
    subset of the still-masked positions (MDLM-style); whatever is still
    masked after the last step is decoded in one final pass.

    Args:
        model: SADModel
        tokenizer: HuggingFace tokenizer
        t_eps: timestep epsilon; stored on the instance, not read by `generate`.
    """

    def __init__(self, model: nn.Module, tokenizer, t_eps: float = 1e-4):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.t_eps = t_eps

    @torch.no_grad()
    def generate(
        self,
        num_samples: int,
        num_steps: int,
        max_length: int = 512,
        device=None,
        show_progress: bool = True,
    ) -> Tuple[torch.Tensor, Dict]:
        """
        Generate `num_samples` sequences of length `max_length`.

        Returns:
            token_ids: [B, S] (on CPU)
            stats: dict with an all-zero 'exit_levels' [B, S] and an empty
                'resolved_over_steps' list, mirroring SADSampler's keys.
        """
        if device is None:
            device = next(self.model.parameters()).device

        B, S = num_samples, max_length
        mask_id = self.tokenizer.mask_token_id
        leaf_emb = self.model.get_leaf_embeddings()  # [V, d]
        last_row = leaf_emb.shape[0] - 1

        def to_emb(ids: torch.Tensor) -> torch.Tensor:
            # ids hold only leaf ids or mask_id; the mask embedding is the
            # mask_id row of leaf_emb, so a plain lookup covers both cases.
            return leaf_emb[ids.clamp(0, last_row)]

        current_ids = torch.full((B, S), mask_id, dtype=torch.long, device=device)

        for step_i in tqdm.trange(
            num_steps, desc="Adjacent sampling", disable=not show_progress
        ):
            leaf_logits, _ = self.model(input_embeddings=to_emb(current_ids))
            leaf_logits[..., mask_id] = float("-inf")
            best_tokens = leaf_logits.argmax(dim=-1)  # [B, S]

            # Only masked positions may change. Each one is unmasked with
            # probability 1 / (remaining steps), reaching 1 on the last step.
            is_masked = current_ids == mask_id
            unmask_prob = 1.0 / max(num_steps - step_i, 1)
            should_unmask = (
                torch.rand(B, S, device=device, dtype=torch.float32) < unmask_prob
            )
            current_ids = torch.where(
                is_masked & should_unmask, best_tokens, current_ids
            )

        # Final pass: decode remaining masks (e.g. when num_steps == 0).
        is_masked = current_ids == mask_id
        if is_masked.any():
            leaf_logits, _ = self.model(input_embeddings=to_emb(current_ids))
            leaf_logits[..., mask_id] = float("-inf")
            current_ids = torch.where(
                is_masked, leaf_logits.argmax(dim=-1), current_ids
            )

        stats = {
            "exit_levels": torch.zeros(B, S, dtype=torch.long),
            "resolved_over_steps": [],
        }
        return current_ids.cpu(), stats