| """ |
| Adaptive Skip / Backoff Sampler for SAD. |
| |
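| Three inference modes: |
| 1. `AdjacentOnlySampler`: baseline – MDLM-style iterative unmasking with argmax (no hierarchy, no skip). |
| 2. `SADSampler`: main method – token-wise adaptive exit over the whole sequence. |
| 3. `SADBlockSampler`: block-autoregressive variant of `SADSampler` (one block at a time, causal context). |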
| |
| Confidence metrics supported: |
| - "max_prob": max_k p_leaf[k] |
| - "neg_entropy": (H_max - H(p_leaf)) / H_max (normalized negative entropy) |
| |
| State per position (SADSampler): |
| - resolved: bool [B, S] – True if token is finalized |
| - current_ids: [B, S] – current token ids (may be coarse cluster ids) |
| - current_level: [B, S] – hierarchy level of current_ids (num_levels = still masked) |
| - exit_levels: [B, S] – which level each token exited at (for logging) |
| |
| APPROXIMATION NOTE: |
| Both SAD samplers currently re-run the full model on all positions each step, |
| even for resolved tokens. Resolved token embeddings are still part of the |
| context (they serve as conditioning). This is correct but not optimal for |
| speed. A KV-cache variant that skips resolved positions is left as TODO. |
| """ |
|
|
| from typing import Dict, List, Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import tqdm.auto as tqdm |
|
|
|
|
def compute_confidence(
    p_leaf: torch.Tensor,
    metric: str = "neg_entropy",
) -> torch.Tensor:
    """
    Compute a per-position confidence score from a probability distribution.

    Args:
        p_leaf: [..., V] probabilities over the last dimension (the samplers
            pass [B, S, V] leaf distributions or [B, S, K] per-level ones).
        metric: "max_prob"    -> max_k p[k]
                "neg_entropy" -> 1 - H(p) / log(V), clamped to [0, 1]

    Returns:
        conf: [...] confidence scores in [0, 1]. For "neg_entropy" the result
            is computed in (at least) float32.

    Raises:
        ValueError: if `metric` is not recognised.
    """
    if metric == "max_prob":
        return p_leaf.max(dim=-1).values
    if metric == "neg_entropy":
        V = p_leaf.shape[-1]
        # Upcast low-precision inputs (float16 / bfloat16); keep float64 as is.
        p = p_leaf.to(torch.promote_types(p_leaf.dtype, torch.float32))
        # xlogy(0, 0) == 0 exactly, so zero-probability entries (e.g. a logit
        # masked to -inf) contribute nothing. The former `p * log(p + 1e-8)`
        # produced NaN in float16, where 1e-8 underflows to 0 (0 * -inf).
        H = -torch.xlogy(p, p).sum(dim=-1)
        H_max = torch.tensor(float(V), device=p.device, dtype=p.dtype).log()
        # Guard V == 1 (H_max == 0).
        conf = 1.0 - H / H_max.clamp(min=1e-8)
        return conf.clamp(0.0, 1.0)
    raise ValueError(f"Unknown confidence metric: {metric}")
|
|
|
|
class SADSampler(nn.Module):
    """
    Token-wise adaptive skip / backoff sampler.

    Each position holds a state id from one shared id space: the tokenizer's
    mask token, a leaf token id (< model.vocab_size), or a level-l cluster id
    in [offsets[l], offsets[l] + level_sizes[l]) with
    offsets[l] = sum(level_sizes[:l]).

    At each denoising step:
      1. Embed the current states and run the model -> leaf logits + hidden h.
      2. For each coarse level l, a softmax over cosine similarities between h
         and the level-l prototypes gives a distribution whose confidence is
         scored with `confidence_metric`.
      3. An unresolved position whose leaf confidence >= tau_0 is finalized to
         the leaf argmax.
      4. Otherwise it moves to the finest level l that is strictly finer than
         its current level and whose confidence >= tau_l; the nearest
         prototype becomes its new state. States never become coarser.
      5. After the last step, still-unresolved positions take the leaf argmax.

    NOTE(review): cosine similarities lie in [-1, 1] and are softmaxed without
    a temperature, so per-level distributions are fairly flat when a level has
    many clusters. Confirm the coarse thresholds are calibrated for this.

    Args:
        model: SADModel (needs get_leaf_embeddings(), vocab_size and
            forward(input_embeddings=...) -> (leaf_logits, h))
        hierarchy: SoftAncestorHierarchy or HardAncestorHierarchy (needs
            num_levels, level_sizes, prototypes)
        tokenizer: HuggingFace tokenizer (needs mask_token_id)
        confidence_metric: "neg_entropy" or "max_prob"
        thresholds: per-level thresholds tau_0..tau_{L-1}. tau_0 is for the
            leaf, tau_1 for level 1, and so on. A higher threshold makes it
            harder to exit at that level. The list is copied.
        freeze_resolved: stored on the instance. NOTE(review): `generate` does
            not read it; finalized tokens are always kept frozen.
    """

    def __init__(
        self,
        model: nn.Module,
        hierarchy,
        tokenizer,
        confidence_metric: str = "neg_entropy",
        thresholds: Optional[List[float]] = None,
        freeze_resolved: bool = True,
    ):
        super().__init__()
        self.model = model
        self.hierarchy = hierarchy
        self.tokenizer = tokenizer
        self.confidence_metric = confidence_metric
        self.freeze_resolved = freeze_resolved

        num_levels = hierarchy.num_levels
        if thresholds is None:
            # Strict leaf threshold, permissive coarse thresholds.
            thresholds = [0.8] + [0.5] * (num_levels - 1)
        assert len(thresholds) >= num_levels, (
            f"need at least {num_levels} thresholds, got {len(thresholds)}"
        )
        # Copy so later mutation of the caller's list cannot alter sampling.
        self.thresholds = list(thresholds)

    @torch.no_grad()
    def generate(
        self,
        num_samples: int,
        num_steps: int,
        max_length: int = 512,
        device=None,
        show_progress: bool = True,
        random_coarse_init: bool = False,
    ) -> Tuple[torch.Tensor, Dict]:
        """
        Generate `num_samples` sequences of length `max_length`.

        Args:
            num_samples: B
            num_steps: maximum number of denoising iterations (the loop stops
                early once every position is finalized)
            max_length: S
            device: target device (defaults to the model's parameter device)
            show_progress: show a tqdm bar and print the early-stop step
            random_coarse_init: if True, start every position at a uniformly
                random cluster of the coarsest level (num_levels - 1) instead
                of the mask token. Has no effect when the hierarchy has no
                coarse levels (num_levels == 1).

        Returns:
            token_ids: [B, S] final generated leaf token ids (on CPU)
            stats: dict with
                'exit_levels' [B, S]: 0 for tokens finalized by the leaf
                    threshold, num_levels for tokens filled in by the final
                    forced argmax pass
                'resolved_over_steps': fraction of finalized positions after
                    each executed step
        """
        if device is None:
            device = next(self.model.parameters()).device

        B, S = num_samples, max_length
        mask_id = self.tokenizer.mask_token_id
        num_levels = self.hierarchy.num_levels
        level_sizes = self.hierarchy.level_sizes
        vocab_size = self.model.vocab_size
        thresholds = self.thresholds

        # offsets[l] = first state id of level l (level 0 = leaf tokens).
        offsets = [0]
        for l in range(1, num_levels):
            offsets.append(offsets[-1] + level_sizes[l - 1])

        leaf_emb = self.model.get_leaf_embeddings()
        mask_emb = leaf_emb[mask_id]

        # Prototypes are fixed during sampling, so prepare them once.
        # Index l - 1 holds level l.
        # proto_embs: cast to the embedding buffer's dtype/device. Index-assigning
        #   a tensor of another dtype into the buffer raises an error.
        # proto_norms: unit-norm float32 copies used for cosine similarity.
        proto_embs = [
            self.hierarchy.prototypes[l - 1].to(device=device, dtype=leaf_emb.dtype)
            for l in range(1, num_levels)
        ]
        proto_norms = [
            F.normalize(self.hierarchy.prototypes[l - 1].float().to(device), dim=-1)
            for l in range(1, num_levels)
        ]

        def ids_to_embeddings(ids: torch.Tensor) -> torch.Tensor:
            """
            Convert [B, S] state ids into [B, S, d] input embeddings.

            mask_id            -> mask_emb
            leaf id (< V)      -> leaf_emb[id]   (mask_id excluded)
            cluster id at lvl l -> prototypes[l - 1][id - offsets[l]]
            """
            embs = torch.zeros(B, S, leaf_emb.shape[-1],
                               device=device, dtype=leaf_emb.dtype)
            mask_pos = ids == mask_id
            if mask_pos.any():
                embs[mask_pos] = mask_emb
            leaf_pos = (ids < vocab_size) & ~mask_pos
            if leaf_pos.any():
                embs[leaf_pos] = leaf_emb[ids[leaf_pos]]
            for l in range(1, num_levels):
                lo = offsets[l]
                cluster_pos = (ids >= lo) & (ids < lo + level_sizes[l])
                if cluster_pos.any():
                    embs[cluster_pos] = proto_embs[l - 1][ids[cluster_pos] - lo]
            return embs

        # current_level[b, s] is the hierarchy level of current_ids[b, s].
        # num_levels is a sentinel meaning "still masked" (coarser than any level).
        current_ids = torch.full((B, S), mask_id, dtype=torch.long, device=device)
        current_level = torch.full((B, S), num_levels, dtype=torch.long, device=device)
        if random_coarse_init and num_levels > 1:
            top = num_levels - 1
            current_ids = torch.randint(
                offsets[top], offsets[top] + level_sizes[top], (B, S),
                dtype=torch.long, device=device,
            )
            current_level.fill_(top)

        resolved = torch.zeros(B, S, dtype=torch.bool, device=device)
        exit_levels = torch.full((B, S), num_levels, dtype=torch.long, device=device)
        resolved_over_steps = []

        for step_i in tqdm.trange(num_steps, desc="sad", disable=not show_progress):
            leaf_logits, h = self.model(input_embeddings=ids_to_embeddings(current_ids))
            # The mask token is never a valid output.
            leaf_logits[..., mask_id] = float("-inf")
            p_leaf = leaf_logits.softmax(dim=-1)
            conf_leaf = compute_confidence(p_leaf, self.confidence_metric)

            # Per coarse level: confidence and nearest-prototype state id.
            h_norm = F.normalize(h.float(), dim=-1)
            conf_per_level = []
            nearest_per_level = []
            for l in range(1, num_levels):
                sim = h_norm @ proto_norms[l - 1].T
                conf_per_level.append(
                    compute_confidence(sim.softmax(dim=-1), self.confidence_metric))
                nearest_per_level.append(sim.argmax(dim=-1) + offsets[l])

            active = ~resolved
            new_ids = current_ids.clone()

            # 1) Leaf exit: finalize confident positions.
            finalize = active & (current_level > 0) & (conf_leaf >= thresholds[0])
            if finalize.any():
                new_ids[finalize] = p_leaf.argmax(dim=-1)[finalize]
                current_level[finalize] = 0
                resolved[finalize] = True
                exit_levels[finalize] = 0
            updated = finalize.clone()

            # 2) Coarse moves, finest level first; only to strictly finer levels.
            for l in range(1, num_levels):
                assign = (active & ~updated & (current_level > l)
                          & (conf_per_level[l - 1] >= thresholds[l]))
                if assign.any():
                    new_ids[assign] = nearest_per_level[l - 1][assign]
                    current_level[assign] = l
                    updated |= assign

            current_ids = new_ids
            resolved_over_steps.append(resolved.float().mean().item())

            if resolved.all():
                if show_progress:
                    print(f"All tokens resolved at step {step_i + 1}")
                break

        # Force-decode whatever is still unresolved with a final leaf argmax.
        unresolved = ~resolved
        if unresolved.any():
            leaf_logits, _ = self.model(input_embeddings=ids_to_embeddings(current_ids))
            leaf_logits[..., mask_id] = float("-inf")
            current_ids[unresolved] = leaf_logits.argmax(dim=-1)[unresolved]

        stats = {
            "exit_levels": exit_levels.cpu(),
            "resolved_over_steps": resolved_over_steps,
        }
        return current_ids.cpu(), stats
|
|
|
|
class SADBlockSampler(nn.Module):
    """
    Block-by-block adaptive skip sampler for Block-AR inference.

    Generates one block at a time in left-to-right order. Within each block
    the same adaptive-skip logic as SADSampler applies:
    - a confident leaf prediction finalizes the token;
    - otherwise the position moves to the finest strictly-finer coarse level
      whose confidence passes its threshold (no backtracking).
    Each step feeds the whole sequence to model.forward_causal(). Earlier
    blocks hold their final leaf ids, and later blocks are still masked.

    NOTE(review): level confidences are softmaxes over cosine similarities in
    [-1, 1] without a temperature. Confirm the thresholds are calibrated for this.

    Args:
        model: SADModel (needs block_size, vocab_size, get_leaf_embeddings()
            and forward_causal(embeddings) -> (leaf_logits, h))
        hierarchy: SoftAncestorHierarchy or HardAncestorHierarchy
        tokenizer: HuggingFace tokenizer (needs mask_token_id)
        confidence_metric: "neg_entropy" or "max_prob"
        thresholds: per-level thresholds [tau_leaf, tau_l1, tau_l2, ...]
            (copied)
        num_steps_per_block: maximum denoising steps per block
    """

    def __init__(
        self,
        model: nn.Module,
        hierarchy,
        tokenizer,
        confidence_metric: str = "neg_entropy",
        thresholds: Optional[List[float]] = None,
        num_steps_per_block: int = 20,
    ):
        super().__init__()
        self.model = model
        self.hierarchy = hierarchy
        self.tokenizer = tokenizer
        self.confidence_metric = confidence_metric
        self.num_steps_per_block = num_steps_per_block

        num_levels = hierarchy.num_levels
        if thresholds is None:
            thresholds = [0.8] + [0.5] * (num_levels - 1)
        assert len(thresholds) >= num_levels, (
            f"need at least {num_levels} thresholds, got {len(thresholds)}"
        )
        # Copy so later mutation of the caller's list cannot alter sampling.
        self.thresholds = list(thresholds)

    @torch.no_grad()
    def generate(
        self,
        num_samples: int,
        max_length: int = 512,
        device=None,
        show_progress: bool = True,
    ) -> Tuple[torch.Tensor, dict]:
        """
        Generate `num_samples` sequences of length `max_length`.

        `max_length` must be a multiple of `model.block_size`.

        Returns:
            token_ids: [B, S] final leaf token ids (on CPU)
            stats: dict with 'exit_levels' [B, S]. The value is 0 for every
                generated token: this sampler also records 0 for tokens filled
                in by the per-block forced argmax pass.
        """
        if device is None:
            device = next(self.model.parameters()).device

        block_size = self.model.block_size
        assert max_length % block_size == 0, \
            f"max_length ({max_length}) must be divisible by block_size ({block_size})"

        B, S = num_samples, max_length
        num_blocks = S // block_size
        mask_id = self.tokenizer.mask_token_id
        num_levels = self.hierarchy.num_levels
        level_sizes = self.hierarchy.level_sizes
        vocab_size = self.model.vocab_size
        thresholds = self.thresholds

        # offsets[l] = first state id of level l (level 0 = leaf tokens).
        offsets = [0]
        for l in range(1, num_levels):
            offsets.append(offsets[-1] + level_sizes[l - 1])

        leaf_emb = self.model.get_leaf_embeddings()
        mask_emb = leaf_emb[mask_id]

        # Prototypes are fixed during sampling, so prepare them once.
        # Index l - 1 holds level l.
        # proto_embs: cast to the embedding buffer's dtype/device. Index-assigning
        #   a tensor of another dtype into the buffer raises an error.
        # proto_norms: unit-norm float32 copies used for cosine similarity.
        proto_embs = [
            self.hierarchy.prototypes[l - 1].to(device=device, dtype=leaf_emb.dtype)
            for l in range(1, num_levels)
        ]
        proto_norms = [
            F.normalize(self.hierarchy.prototypes[l - 1].float().to(device), dim=-1)
            for l in range(1, num_levels)
        ]

        def ids_to_embeddings(ids: torch.Tensor) -> torch.Tensor:
            """[B, S] state ids (mask / leaf / offset cluster) -> [B, S, d] embeddings."""
            out = torch.zeros(B, ids.shape[1], leaf_emb.shape[-1],
                              device=device, dtype=leaf_emb.dtype)
            mask_pos = ids == mask_id
            if mask_pos.any():
                out[mask_pos] = mask_emb
            leaf_pos = (ids < vocab_size) & ~mask_pos
            if leaf_pos.any():
                out[leaf_pos] = leaf_emb[ids[leaf_pos]]
            for l in range(1, num_levels):
                lo = offsets[l]
                cluster_pos = (ids >= lo) & (ids < lo + level_sizes[l])
                if cluster_pos.any():
                    out[cluster_pos] = proto_embs[l - 1][ids[cluster_pos] - lo]
            return out

        current_ids = torch.full((B, S), mask_id, dtype=torch.long, device=device)
        exit_levels = torch.full((B, S), num_levels, dtype=torch.long, device=device)

        for blk in tqdm.tqdm(range(num_blocks), desc="sad_block",
                             disable=not show_progress):
            blk_start = blk * block_size
            blk_end = blk_start + block_size
            # Basic-slice view: masked writes into it land in exit_levels.
            blk_exit = exit_levels[:, blk_start:blk_end]

            # Per-block state; blk_level == num_levels means "still masked".
            blk_current = current_ids[:, blk_start:blk_end].clone()
            blk_level = torch.full((B, block_size), num_levels,
                                   dtype=torch.long, device=device)
            blk_resolved = torch.zeros(B, block_size, dtype=torch.bool, device=device)

            for _ in range(self.num_steps_per_block):
                current_ids[:, blk_start:blk_end] = blk_current
                leaf_logits, h = self.model.forward_causal(ids_to_embeddings(current_ids))

                # Only the current block is updated in this step.
                blk_logits = leaf_logits[:, blk_start:blk_end, :]
                blk_logits[..., mask_id] = float("-inf")
                p_leaf = blk_logits.softmax(dim=-1)
                conf_leaf = compute_confidence(p_leaf, self.confidence_metric)

                # Per coarse level: confidence and nearest-prototype state id.
                h_norm = F.normalize(h[:, blk_start:blk_end, :].float(), dim=-1)
                conf_per_level = []
                nearest_per_level = []
                for l in range(1, num_levels):
                    sim = h_norm @ proto_norms[l - 1].T
                    conf_per_level.append(
                        compute_confidence(sim.softmax(dim=-1), self.confidence_metric))
                    nearest_per_level.append(sim.argmax(dim=-1) + offsets[l])

                active = ~blk_resolved
                new_blk = blk_current.clone()

                # 1) Leaf exit: finalize confident positions.
                finalize = active & (blk_level > 0) & (conf_leaf >= thresholds[0])
                if finalize.any():
                    new_blk[finalize] = p_leaf.argmax(dim=-1)[finalize]
                    blk_level[finalize] = 0
                    blk_resolved[finalize] = True
                    blk_exit[finalize] = 0
                updated = finalize.clone()

                # 2) Coarse moves, finest level first; only to strictly finer levels.
                for l in range(1, num_levels):
                    assign = (active & ~updated & (blk_level > l)
                              & (conf_per_level[l - 1] >= thresholds[l]))
                    if assign.any():
                        new_blk[assign] = nearest_per_level[l - 1][assign]
                        blk_level[assign] = l
                        updated |= assign

                blk_current = new_blk
                if blk_resolved.all():
                    break

            # Force-decode the block's remaining positions with a leaf argmax.
            unresolved = ~blk_resolved
            if unresolved.any():
                current_ids[:, blk_start:blk_end] = blk_current
                leaf_logits, _ = self.model.forward_causal(ids_to_embeddings(current_ids))
                blk_logits = leaf_logits[:, blk_start:blk_end, :]
                blk_logits[..., mask_id] = float("-inf")
                blk_current[unresolved] = blk_logits.argmax(dim=-1)[unresolved]
                blk_exit[unresolved] = 0

            # Commit the finished block as clean context for later blocks.
            current_ids[:, blk_start:blk_end] = blk_current

        stats = {"exit_levels": exit_levels.cpu()}
        return current_ids.cpu(), stats
|
|
|
|
class AdjacentOnlySampler(nn.Module):
    """
    Baseline sampler: MDLM-style iterative unmasking with argmax, no hierarchy.

    Starts from an all-mask sequence. At step i (of num_steps), every
    still-masked position is independently unmasked with probability
    1 / (num_steps - i). An unmasked position takes the model's argmax leaf
    prediction, with the mask token excluded. The last step uses probability 1,
    so every position is decoded by the end. No cluster states and no
    confidence-based exit are used. Random draws come from the global torch RNG.

    Args:
        model: SADModel (needs get_leaf_embeddings() and
            forward(input_embeddings=...) -> (leaf_logits, h))
        tokenizer: HuggingFace tokenizer (needs mask_token_id)
        t_eps: timestep epsilon. Stored on the instance but not read by
            `generate`.
    """

    def __init__(self, model: nn.Module, tokenizer, t_eps: float = 1e-4):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.t_eps = t_eps

    @torch.no_grad()
    def generate(
        self,
        num_samples: int,
        num_steps: int,
        max_length: int = 512,
        device=None,
        show_progress: bool = True,
    ) -> Tuple[torch.Tensor, Dict]:
        """
        Generate `num_samples` sequences of length `max_length`.

        Returns:
            token_ids: [B, S] generated leaf token ids (on CPU)
            stats: dict with 'exit_levels' (all zeros, [B, S]) and an empty
                'resolved_over_steps' list, for interface parity with the
                SAD samplers.
        """
        if device is None:
            device = next(self.model.parameters()).device

        B, S = num_samples, max_length
        mask_id = self.tokenizer.mask_token_id

        leaf_emb = self.model.get_leaf_embeddings()
        mask_emb = leaf_emb[mask_id]

        def to_emb(ids: torch.Tensor) -> torch.Tensor:
            """[B, S] leaf ids (or mask_id) -> [B, S, d] embeddings."""
            # Clamp guards the lookup against out-of-range ids.
            embs = leaf_emb[ids.clamp(0, leaf_emb.shape[0] - 1)]
            mask_pos = ids == mask_id
            if mask_pos.any():
                embs[mask_pos] = mask_emb
            return embs

        current_ids = torch.full((B, S), mask_id, dtype=torch.long, device=device)

        for step_i in tqdm.trange(num_steps, desc="Adjacent sampling", disable=not show_progress):
            leaf_logits, _ = self.model(input_embeddings=to_emb(current_ids))
            # The mask token is never a valid output.
            leaf_logits[..., mask_id] = float("-inf")

            is_masked = current_ids == mask_id
            best_tokens = leaf_logits.argmax(dim=-1)

            # Unmask each remaining masked position with prob 1/(steps left).
            unmask_prob = 1.0 / max(num_steps - step_i, 1)
            should_unmask = torch.rand_like(current_ids.float()) < unmask_prob
            update = is_masked & should_unmask
            current_ids = torch.where(update, best_tokens, current_ids)

        # Safety net: the last loop step unmasks with probability 1, so this
        # only fires when the loop did not run (num_steps <= 0).
        is_masked = current_ids == mask_id
        if is_masked.any():
            leaf_logits, _ = self.model(input_embeddings=to_emb(current_ids))
            leaf_logits[..., mask_id] = float("-inf")
            final_tokens = leaf_logits.argmax(dim=-1)
            current_ids = torch.where(is_masked, final_tokens, current_ids)

        stats = {
            "exit_levels": torch.zeros(B, S, dtype=torch.long),
            "resolved_over_steps": [],
        }
        return current_ids.cpu(), stats
|
|