"""
Block-wise Sampler for SAD.

Instead of token-wise adaptive decoding, this sampler operates on blocks of
tokens. Given a context length of 512 and block size of 8, we have 64 blocks.

Block-wise adaptive: resolve entire blocks at once based on aggregate
confidence. Per-level confidence is computed with ``compute_confidence`` on
the leaf distribution and on its upward projections through the hierarchy.

NOTE(review): earlier documentation described the per-level confidence as
"h x prototype inner products, matching SADSampler and SADBlockSampler".
The code in this module derives the per-level distributions through
``hierarchy.project_upward`` and does not use the model's second output --
confirm which description is intended.
"""

import math
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm.auto as tqdm

from .sampler import compute_confidence


class BlockWiseAdaptiveSampler(nn.Module):
    """
    Block-wise adaptive skip/backoff sampler.

    Divides the sequence into blocks (e.g., 512 / 8 = 64 blocks).

    Each iteration:
        1. Run the model forward to obtain leaf logits.
        2. Project the leaf distribution upward through the hierarchy and
           compute a per-token confidence at every level.
        3. Aggregate token confidences to block-level confidences.
        4. Resolve entire blocks whose confidence passes the level threshold
           (finest level first).

    Args:
        model: SADModel
        hierarchy: SoftAncestorHierarchy or HardAncestorHierarchy
        tokenizer: HuggingFace tokenizer (must define ``mask_token_id``)
        block_size: number of tokens per block (default 8)
        confidence_metric: "neg_entropy" or "max_prob"
        thresholds: per-level thresholds [tau_leaf, tau_l1, ...]; defaults to
            0.8 for the leaf level and 0.5 for every coarser level
        block_agg: how to aggregate token confidences to block confidence:
            "mean", "min", or "max"
        freeze_resolved_blocks: if True (default), a block that has been
            resolved for a sample is never rewritten on later steps. If
            False, a resolved block may be re-resolved on later steps as long
            as that block is still unresolved for at least one sample in the
            batch (the behaviour before this flag was honoured).
        t_eps: smallest time value; the time schedule runs from
            ``1 - t_eps`` down to ``t_eps`` (default 1e-4)
    """

    def __init__(
        self,
        model: nn.Module,
        hierarchy,
        tokenizer,
        block_size: int = 8,
        confidence_metric: str = "neg_entropy",
        thresholds: Optional[List[float]] = None,
        block_agg: str = "mean",
        freeze_resolved_blocks: bool = True,
        t_eps: float = 1e-4,
    ):
        super().__init__()
        self.model = model
        self.hierarchy = hierarchy
        self.tokenizer = tokenizer
        self.block_size = block_size
        self.confidence_metric = confidence_metric
        self.block_agg = block_agg
        self.freeze_resolved_blocks = freeze_resolved_blocks
        # generate() builds its time schedule from t_eps, so it must exist.
        self.t_eps = t_eps

        num_levels = hierarchy.num_levels
        if thresholds is None:
            thresholds = [0.8] + [0.5] * (num_levels - 1)
        self.thresholds = thresholds

        assert block_size > 0, "block_size must be positive"
        assert block_agg in ["mean", "min", "max"], f"Unknown block_agg: {block_agg}"

    def _aggregate_block_confidence(
        self, token_conf: torch.Tensor
    ) -> torch.Tensor:
        """
        Aggregate token-level confidences to block-level.

        When the sequence length is not a multiple of ``block_size`` the last
        block is shorter; it is aggregated over its real tokens only, so the
        padding never influences the result.

        Args:
            token_conf: [B, S] per-token confidence

        Returns:
            block_conf: [B, num_blocks] per-block confidence

        Raises:
            ValueError: if ``self.block_agg`` is not "mean", "min" or "max".
        """
        agg = self.block_agg
        if agg not in ("mean", "min", "max"):
            raise ValueError(f"Unknown block_agg: {agg}")

        B, S = token_conf.shape
        block_size = self.block_size
        num_blocks = math.ceil(S / block_size)
        pad_len = num_blocks * block_size - S

        # Pad with the identity element of the reduction so that padded
        # positions cannot change the aggregate of the trailing block.
        if agg == "min":
            pad_value = float("inf")
        elif agg == "max":
            pad_value = float("-inf")
        else:
            pad_value = 0.0
        if pad_len > 0:
            token_conf = F.pad(token_conf, (0, pad_len), value=pad_value)

        blocks = token_conf.reshape(B, num_blocks, block_size)

        if agg == "min":
            return blocks.min(dim=-1).values
        if agg == "max":
            return blocks.max(dim=-1).values

        # Mean: divide each block's sum by its number of real tokens.
        counts = token_conf.new_full((num_blocks,), float(block_size))
        if pad_len > 0:
            counts[-1] = block_size - pad_len
        return blocks.sum(dim=-1) / counts

    def _get_block_resolution_level(
        self, block_conf: Dict[int, torch.Tensor]
    ) -> torch.Tensor:
        """
        Determine which level each block should resolve to.

        Levels are tested from 0 (leaf) upward; a block takes the first level
        whose confidence reaches that level's threshold.

        Args:
            block_conf: dict mapping level to [B, num_blocks] confidence

        Returns:
            resolve_level: [B, num_blocks] int, 0=leaf, 1=level1, etc.,
                -1=unresolved
        """
        leaf_conf = block_conf[0]
        B, num_blocks = leaf_conf.shape

        # Start with -1 (unresolved)
        resolve_level = torch.full(
            (B, num_blocks), -1, dtype=torch.long, device=leaf_conf.device
        )

        for level, tau in enumerate(self.thresholds):
            # Only blocks not claimed by a finer level may take this one.
            newly_resolved = (resolve_level == -1) & (block_conf[level] >= tau)
            resolve_level[newly_resolved] = level

        return resolve_level

    @torch.no_grad()
    def generate(
        self,
        num_samples: int,
        num_steps: int,
        max_length: int = 512,
        device=None,
        show_progress: bool = True,
        random_coarse_init: bool = False,
    ) -> Tuple[torch.Tensor, Dict]:
        """
        Generate sequences using block-wise adaptive decoding.

        Args:
            num_samples: batch size B
            num_steps: maximum number of decoding iterations
            max_length: sequence length S
            device: device to run on; defaults to the model's parameter device
            show_progress: show a tqdm bar and print when all blocks resolve
            random_coarse_init: if True and the hierarchy has at least two
                levels, start from random ids of the last hierarchy level
                instead of mask tokens

        Returns:
            token_ids: [B, S] final ids (on CPU). Blocks resolved at a level
                above 0 hold that level's ids shifted by
                ``sum(level_sizes[:level])``.
            stats: dict with block-level statistics

        Raises:
            ValueError: if the tokenizer has no mask token id.
        """
        if device is None:
            device = next(self.model.parameters()).device

        mask_id = self.tokenizer.mask_token_id
        if mask_id is None:
            # Indexing logits with None would silently overwrite every logit.
            raise ValueError("tokenizer.mask_token_id is None; a mask token is required")

        B, S = num_samples, max_length
        block_size = self.block_size
        num_blocks = math.ceil(S / block_size)

        # Initialize
        if random_coarse_init and self.hierarchy.num_levels >= 2:
            K_top = self.hierarchy.level_sizes[-1]
            offset = sum(self.hierarchy.level_sizes[:-1])
            current_ids = torch.randint(0, K_top, (B, S), device=device) + offset
        else:
            current_ids = torch.full((B, S), mask_id, dtype=torch.long, device=device)

        # Track block resolution
        block_resolved = torch.zeros(B, num_blocks, dtype=torch.bool, device=device)
        block_exit_levels = torch.full(
            (B, num_blocks), -1, dtype=torch.long, device=device
        )
        token_exit_levels = torch.full(
            (B, S), self.hierarchy.num_levels - 1, dtype=torch.long, device=device
        )

        ts = torch.linspace(1.0 - self.t_eps, self.t_eps, num_steps, device=device)

        # Check if hierarchy is soft (needs embeddings)
        is_soft = hasattr(self.hierarchy, 'prototypes') and \
            any(p.requires_grad for p in self.hierarchy.parameters())

        resolved_over_steps = []

        for step_i in tqdm.trange(num_steps, desc="Block-wise SAD",
                                  disable=not show_progress):
            t_batch = ts[step_i].expand(B)

            # Forward pass
            leaf_logits, _ = self.model(input_ids=current_ids, t=t_batch)
            leaf_logits[..., mask_id] = float('-inf')
            p_leaf = leaf_logits.softmax(dim=-1)  # [B, S, V]

            # Project upward
            if is_soft:
                leaf_emb = self.model.get_leaf_embeddings()
                assignments = self.hierarchy.get_all_assignments(leaf_emb)
            else:
                assignments = self.hierarchy.get_all_assignments()
            # p_levels[l-1] = p^(l): [B, S, K_l]
            p_levels = self.hierarchy.project_upward(p_leaf, assignments=assignments)

            # Per-token confidence at each level, aggregated per block
            block_conf = {
                0: self._aggregate_block_confidence(
                    compute_confidence(p_leaf, self.confidence_metric)
                ),
            }
            for li, p_l in enumerate(p_levels, start=1):
                block_conf[li] = self._aggregate_block_confidence(
                    compute_confidence(p_l, self.confidence_metric)
                )

            # Determine resolution level for each block: [B, num_blocks]
            resolve_level = self._get_block_resolution_level(block_conf)

            # Blocks that get (re)written this step.
            updatable = resolve_level >= 0
            if self.freeze_resolved_blocks:
                updatable = updatable & ~block_resolved
            else:
                # Previous behaviour: only blocks already resolved for the
                # whole batch are left untouched.
                updatable = updatable & ~block_resolved.all(dim=0, keepdim=True)

            # Broadcast each block's level to its tokens; -1 means "keep".
            block_level = resolve_level.masked_fill(~updatable, -1)
            token_level = block_level.repeat_interleave(block_size, dim=1)[:, :S]

            # Leaf level: greedy token choice
            new_ids = torch.where(token_level == 0, p_leaf.argmax(dim=-1), current_ids)
            # Intermediate levels: greedy ancestor, shifted into extended vocab
            for lvl, p_l in enumerate(p_levels, start=1):
                offset = sum(self.hierarchy.level_sizes[:lvl])
                new_ids = torch.where(
                    token_level == lvl, p_l.argmax(dim=-1) + offset, new_ids
                )

            token_exit_levels = torch.where(
                token_level >= 0, token_level, token_exit_levels
            )
            block_exit_levels = torch.where(updatable, resolve_level, block_exit_levels)
            block_resolved = block_resolved | updatable
            current_ids = new_ids

            resolved_over_steps.append(block_resolved.float().mean().item())

            if block_resolved.all():
                if show_progress:
                    print(f"All blocks resolved at step {step_i + 1}")
                break

        # Final pass: force unresolved to leaf
        unresolved_blocks = ~block_resolved
        if unresolved_blocks.any():
            leaf_logits, _ = self.model(
                input_ids=current_ids,
                t=torch.full((B,), self.t_eps, device=device),
            )
            leaf_logits[..., mask_id] = float('-inf')
            final_tokens = leaf_logits.argmax(dim=-1)

            unresolved_tokens = unresolved_blocks.repeat_interleave(
                block_size, dim=1
            )[:, :S]
            current_ids = torch.where(unresolved_tokens, final_tokens, current_ids)
            token_exit_levels = token_exit_levels.masked_fill(unresolved_tokens, 0)
            block_exit_levels = block_exit_levels.masked_fill(unresolved_blocks, 0)

        stats = {
            "block_exit_levels": block_exit_levels.cpu(),
            "token_exit_levels": token_exit_levels.cpu(),
            "block_resolved_over_steps": resolved_over_steps,
            "num_blocks": num_blocks,
            "block_size": block_size,
        }

        return current_ids.cpu(), stats


class BlockWiseAdjacentSampler(nn.Module):
    """
    Baseline block-wise sampler without adaptive skipping.

    Decodes block by block in a fixed left-to-right order: each step fills the
    still-masked positions of one block with the greedy leaf prediction.

    Args:
        model: model called as ``model(input_ids=..., t=...)`` whose first
            output is the leaf logits
        tokenizer: tokenizer providing ``mask_token_id``
        block_size: number of tokens per block (default 8)
        t_eps: smallest time value; the time schedule runs from
            ``1 - t_eps`` down to ``t_eps`` (default 1e-4)
    """

    def __init__(
        self,
        model: nn.Module,
        tokenizer,
        block_size: int = 8,
        t_eps: float = 1e-4,
    ):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.block_size = block_size
        self.t_eps = t_eps

    @torch.no_grad()
    def generate(
        self,
        num_samples: int,
        num_steps: int,
        max_length: int = 512,
        device=None,
        show_progress: bool = True,
    ) -> Tuple[torch.Tensor, Dict]:
        """
        Simple block-wise decoding without adaptive skip.

        Args:
            num_samples: batch size B
            num_steps: number of decoding iterations
            max_length: sequence length S
            device: device to run on; defaults to the model's parameter device
            show_progress: show a tqdm progress bar

        Returns:
            token_ids: [B, S] final tokens (on CPU)
            stats: dict with all-zero exit levels plus ``num_blocks`` and
                ``block_size``

        Raises:
            ValueError: if the tokenizer has no mask token id.
        """
        if device is None:
            device = next(self.model.parameters()).device

        mask_id = self.tokenizer.mask_token_id
        if mask_id is None:
            # Indexing logits with None would silently overwrite every logit.
            raise ValueError("tokenizer.mask_token_id is None; a mask token is required")

        B, S = num_samples, max_length
        block_size = self.block_size
        num_blocks = math.ceil(S / block_size)

        current_ids = torch.full((B, S), mask_id, dtype=torch.long, device=device)
        ts = torch.linspace(1.0 - self.t_eps, self.t_eps, num_steps, device=device)

        # Assign steps per block
        steps_per_block = max(1, num_steps // num_blocks)

        for step_i in tqdm.trange(num_steps, desc="Block-adjacent",
                                  disable=not show_progress):
            t_batch = ts[step_i].expand(B)

            leaf_logits, _ = self.model(input_ids=current_ids, t=t_batch)
            leaf_logits[..., mask_id] = float('-inf')

            # Determine which block to update; surplus steps stay on the last
            block_idx = min(step_i // steps_per_block, num_blocks - 1)
            start = block_idx * block_size
            end = min(start + block_size, S)

            # Greedy prediction for the current block only: [B, block_size]
            block_tokens = leaf_logits[:, start:end].argmax(dim=-1)

            # Only update if still masked
            block_ids = current_ids[:, start:end]
            current_ids[:, start:end] = torch.where(
                block_ids == mask_id, block_tokens, block_ids
            )

        # Final fill: blocks never reached by the loop are decoded in one pass
        is_masked = (current_ids == mask_id)
        if is_masked.any():
            leaf_logits, _ = self.model(
                input_ids=current_ids,
                t=torch.full((B,), self.t_eps, device=device),
            )
            leaf_logits[..., mask_id] = float('-inf')
            final_tokens = leaf_logits.argmax(dim=-1)
            current_ids = torch.where(is_masked, final_tokens, current_ids)

        stats = {
            "block_exit_levels": torch.zeros(B, num_blocks, dtype=torch.long),
            "token_exit_levels": torch.zeros(B, S, dtype=torch.long),
            "num_blocks": num_blocks,
            "block_size": block_size,
        }

        return current_ids.cpu(), stats