| """ |
| Block-wise Sampler for SAD. |
| |
| Instead of token-wise adaptive decoding, this sampler operates on blocks of tokens. |
| Given a context length of 512 and block size of 8, we have 64 blocks. |
| |
Block-wise adaptive: resolve entire blocks at once based on aggregate confidence.
Per-level confidence is computed by projecting the leaf distribution up the
hierarchy (`hierarchy.project_upward`) and scoring each level's distribution
with the shared `compute_confidence` helper from `.sampler`.
| """ |
|
|
| from typing import Dict, List, Optional, Tuple |
| import math |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import tqdm.auto as tqdm |
|
|
| from .sampler import compute_confidence |
|
|
|
|
class BlockWiseAdaptiveSampler(nn.Module):
    """
    Block-wise adaptive skip/backoff sampler.

    Divides the sequence into blocks (e.g., 512 / 8 = 64 blocks).
    Each iteration:
        1. Run the model forward -> leaf logits (the model's second output is unused)
        2. Project the leaf distribution up the hierarchy and compute a per-token
           confidence at every level with ``compute_confidence``
        3. Aggregate token confidences to block-level confidence
        4. Resolve entire blocks that pass a level's threshold, preferring the
           finest (lowest-index) level that passes

    Args:
        model: SADModel
        hierarchy: SoftAncestorHierarchy or HardAncestorHierarchy
        tokenizer: HuggingFace tokenizer
        block_size: number of tokens per block (default 8)
        confidence_metric: "neg_entropy" or "max_prob"
        thresholds: per-level thresholds [tau_leaf, tau_l1, ...];
            defaults to [0.8] + [0.5] * (num_levels - 1)
        block_agg: how to aggregate token confidences to block confidence:
            "mean", "min", or "max"
        freeze_resolved_blocks: if True, once a block is resolved for a sample
            it is never rewritten for that sample in later steps. If False, a
            resolved block may be re-resolved (possibly at another level) until
            every sample in the batch has resolved that block.
        t_eps: smallest diffusion time. The step schedule runs from 1 - t_eps
            down to t_eps, and the final clean-up pass uses t = t_eps.
    """

    def __init__(
        self,
        model: nn.Module,
        hierarchy,
        tokenizer,
        block_size: int = 8,
        confidence_metric: str = "neg_entropy",
        thresholds: Optional[List[float]] = None,
        block_agg: str = "mean",
        freeze_resolved_blocks: bool = True,
        t_eps: float = 1e-4,
    ):
        super().__init__()
        self.model = model
        self.hierarchy = hierarchy
        self.tokenizer = tokenizer
        self.block_size = block_size
        self.confidence_metric = confidence_metric
        self.block_agg = block_agg
        self.freeze_resolved_blocks = freeze_resolved_blocks
        # generate() reads self.t_eps for the time schedule and the final pass.
        self.t_eps = t_eps

        num_levels = hierarchy.num_levels
        if thresholds is None:
            thresholds = [0.8] + [0.5] * (num_levels - 1)
        self.thresholds = thresholds

        assert block_size > 0, "block_size must be positive"
        assert block_agg in ["mean", "min", "max"], f"Unknown block_agg: {block_agg}"

    def _blocks_to_tokens(self, block_values: torch.Tensor, seq_len: int) -> torch.Tensor:
        """
        Broadcast a per-block tensor to per-token resolution.

        Args:
            block_values: [B, num_blocks] tensor of any dtype
            seq_len: sequence length S (the last block may be shorter than block_size)

        Returns:
            [B, seq_len] tensor where token i carries the value of block i // block_size
        """
        B, num_blocks = block_values.shape
        expanded = block_values.unsqueeze(-1).expand(B, num_blocks, self.block_size)
        return expanded.reshape(B, num_blocks * self.block_size)[:, :seq_len]

    def _aggregate_block_confidence(
        self, token_conf: torch.Tensor
    ) -> torch.Tensor:
        """
        Aggregate token-level confidences to block-level.

        If S is not a multiple of block_size, the last block is shorter. Only
        its real tokens take part in the aggregate, so padding never dilutes
        a mean or wins a min/max.

        Args:
            token_conf: [B, S] per-token confidence

        Returns:
            block_conf: [B, num_blocks] per-block confidence, num_blocks = ceil(S / block_size)

        Raises:
            ValueError: if ``self.block_agg`` is not "mean", "min" or "max".
        """
        B, S = token_conf.shape
        block_size = self.block_size
        num_blocks = math.ceil(S / block_size)
        pad_len = num_blocks * block_size - S

        def as_blocks(values: torch.Tensor, fill: float) -> torch.Tensor:
            # Pad the ragged tail with a value that is neutral for the reduction.
            if pad_len > 0:
                values = F.pad(values, (0, pad_len), value=fill)
            return values.reshape(B, num_blocks, block_size)

        if self.block_agg == "mean":
            sums = as_blocks(token_conf, 0.0).sum(dim=-1)
            counts = as_blocks(torch.ones_like(token_conf), 0.0).sum(dim=-1)
            return sums / counts
        elif self.block_agg == "min":
            return as_blocks(token_conf, float("inf")).min(dim=-1).values
        elif self.block_agg == "max":
            return as_blocks(token_conf, float("-inf")).max(dim=-1).values
        else:
            raise ValueError(f"Unknown block_agg: {self.block_agg}")

    def _get_block_resolution_level(
        self, block_conf: Dict[int, torch.Tensor]
    ) -> torch.Tensor:
        """
        Determine which level each block should resolve to.

        Levels are tried in order 0, 1, ..., len(thresholds) - 1. A block takes
        the first level whose confidence reaches that level's threshold.

        Args:
            block_conf: dict mapping level to [B, num_blocks] confidence;
                must contain every key in range(len(self.thresholds))

        Returns:
            resolve_level: [B, num_blocks] int, 0=leaf, 1=level1, etc., -1=unresolved
        """
        B, num_blocks = block_conf[0].shape
        device = block_conf[0].device

        resolve_level = torch.full((B, num_blocks), -1, dtype=torch.long, device=device)

        for level in range(len(self.thresholds)):
            conf = block_conf[level]
            tau = self.thresholds[level]

            # Only blocks not already claimed by a finer level may take this one.
            unresolved = (resolve_level == -1)
            should_resolve = unresolved & (conf >= tau)
            resolve_level[should_resolve] = level

        return resolve_level

    @torch.no_grad()
    def generate(
        self,
        num_samples: int,
        num_steps: int,
        max_length: int = 512,
        device=None,
        show_progress: bool = True,
        random_coarse_init: bool = False,
    ) -> Tuple[torch.Tensor, Dict]:
        """
        Generate sequences using block-wise adaptive decoding.

        Args:
            num_samples: batch size B
            num_steps: maximum number of denoising steps. The loop stops early
                once every block of every sample is resolved.
            max_length: sequence length S
            device: device to sample on; defaults to the model's parameter device
            show_progress: show a progress bar and report an early finish
            random_coarse_init: if True and the hierarchy has >= 2 levels, start
                from uniformly random top-level ancestor ids instead of mask tokens

        Returns:
            token_ids: [B, S] final tokens (on CPU). Positions in blocks resolved at
                level > 0 hold ancestor ids offset by sum(level_sizes[:level]).
            stats: dict with block-level statistics
        """
        if device is None:
            device = next(self.model.parameters()).device

        B, S = num_samples, max_length
        block_size = self.block_size
        num_blocks = math.ceil(S / block_size)
        mask_id = self.tokenizer.mask_token_id

        if random_coarse_init and self.hierarchy.num_levels >= 2:
            K_top = self.hierarchy.level_sizes[-1]
            offset = sum(self.hierarchy.level_sizes[:-1])
            current_ids = torch.randint(0, K_top, (B, S), device=device) + offset
        else:
            current_ids = torch.full((B, S), mask_id, dtype=torch.long, device=device)

        block_resolved = torch.zeros(B, num_blocks, dtype=torch.bool, device=device)
        block_exit_levels = torch.full((B, num_blocks), -1, dtype=torch.long, device=device)
        token_exit_levels = torch.full(
            (B, S), self.hierarchy.num_levels - 1, dtype=torch.long, device=device
        )

        ts = torch.linspace(1.0 - self.t_eps, self.t_eps, num_steps, device=device)

        # Hierarchies with trainable prototypes derive assignments from the model's
        # current leaf embeddings; other hierarchies are queried without arguments.
        is_soft = hasattr(self.hierarchy, 'prototypes') and \
            any(p.requires_grad for p in self.hierarchy.parameters())

        resolved_over_steps = []

        for step_i in tqdm.trange(num_steps, desc="Block-wise SAD", disable=not show_progress):
            t_batch = ts[step_i].expand(B)

            leaf_logits, _ = self.model(input_ids=current_ids, t=t_batch)
            leaf_logits[..., mask_id] = float('-inf')  # never emit [MASK]
            p_leaf = leaf_logits.softmax(dim=-1)

            if is_soft:
                leaf_emb = self.model.get_leaf_embeddings()
                assignments = self.hierarchy.get_all_assignments(leaf_emb)
            else:
                assignments = self.hierarchy.get_all_assignments()

            p_levels = self.hierarchy.project_upward(p_leaf, assignments=assignments)

            # Block confidence per level: key 0 is the leaf, key li is p_levels[li - 1].
            block_conf = {
                0: self._aggregate_block_confidence(
                    compute_confidence(p_leaf, self.confidence_metric)
                ),
            }
            for li, p_l in enumerate(p_levels, start=1):
                block_conf[li] = self._aggregate_block_confidence(
                    compute_confidence(p_l, self.confidence_metric)
                )

            resolve_level = self._get_block_resolution_level(block_conf)

            # (sample, block) pairs written this step.
            update = resolve_level >= 0
            if self.freeze_resolved_blocks:
                update &= ~block_resolved
            else:
                # Skip only block columns already resolved for every sample.
                update &= ~block_resolved.all(dim=0, keepdim=True)

            token_update = self._blocks_to_tokens(update, S)
            token_level = self._blocks_to_tokens(resolve_level, S)

            new_ids = current_ids.clone()
            for lvl in range(len(self.thresholds)):
                write = token_update & (token_level == lvl)
                if not write.any():
                    continue
                if lvl == 0:
                    level_ids = p_leaf.argmax(dim=-1)
                else:
                    # Ancestor ids of level lvl live after all finer levels' ids.
                    offset = sum(self.hierarchy.level_sizes[:lvl])
                    level_ids = p_levels[lvl - 1].argmax(dim=-1) + offset
                new_ids = torch.where(write, level_ids, new_ids)
                token_exit_levels[write] = lvl

            block_resolved |= update
            block_exit_levels = torch.where(update, resolve_level, block_exit_levels)

            current_ids = new_ids
            resolved_over_steps.append(block_resolved.float().mean().item())

            if block_resolved.all():
                if show_progress:
                    print(f"All blocks resolved at step {step_i + 1}")
                break

        # Blocks that never passed a threshold get one final leaf argmax at t = t_eps.
        unresolved_blocks = ~block_resolved
        if unresolved_blocks.any():
            leaf_logits, _ = self.model(
                input_ids=current_ids,
                t=torch.full((B,), self.t_eps, device=device),
            )
            leaf_logits[..., mask_id] = float('-inf')
            final_tokens = leaf_logits.argmax(dim=-1)

            fill = self._blocks_to_tokens(unresolved_blocks, S)
            current_ids = torch.where(fill, final_tokens, current_ids)
            token_exit_levels[fill] = 0
            block_exit_levels[unresolved_blocks] = 0

        stats = {
            "block_exit_levels": block_exit_levels.cpu(),
            "token_exit_levels": token_exit_levels.cpu(),
            "block_resolved_over_steps": resolved_over_steps,
            "num_blocks": num_blocks,
            "block_size": block_size,
        }
        return current_ids.cpu(), stats
|
|
|
|
class BlockWiseAdjacentSampler(nn.Module):
    """
    Baseline block-wise sampler without adaptive skipping.

    Decodes blocks strictly left-to-right. The denoising steps are split
    across blocks, num_steps // num_blocks each (at least one); the last
    block absorbs any leftover steps. At each step, the still-masked
    positions of the current block are filled with the argmax of the
    model's leaf logits. Positions still masked after the loop (e.g. when
    num_steps < num_blocks) are filled by one final forward pass at
    t = t_eps.

    Args:
        model: called as model(input_ids=..., t=...) and expected to return
            (leaf_logits, _)
        tokenizer: must provide mask_token_id
        block_size: number of tokens per block (default 8)
        t_eps: smallest diffusion time. The step schedule runs from
            1 - t_eps down to t_eps.

    Raises:
        ValueError: if block_size is not positive.
    """

    def __init__(
        self,
        model: nn.Module,
        tokenizer,
        block_size: int = 8,
        t_eps: float = 1e-4,
    ):
        super().__init__()
        if block_size <= 0:
            raise ValueError("block_size must be positive")
        self.model = model
        self.tokenizer = tokenizer
        self.block_size = block_size
        self.t_eps = t_eps

    @torch.no_grad()
    def generate(
        self,
        num_samples: int,
        num_steps: int,
        max_length: int = 512,
        device=None,
        show_progress: bool = True,
    ) -> Tuple[torch.Tensor, Dict]:
        """
        Simple left-to-right block-wise decoding without adaptive skip.

        Args:
            num_samples: batch size B
            num_steps: number of denoising steps
            max_length: sequence length S (must be positive)
            device: device to sample on; defaults to the model's parameter device
            show_progress: show a progress bar

        Returns:
            token_ids: [B, S] final tokens (on CPU)
            stats: dict with all-zero exit levels (every token decoded at the leaf)

        Raises:
            ValueError: if max_length is not positive.
        """
        if max_length <= 0:
            # Otherwise num_blocks == 0 and steps_per_block divides by zero.
            raise ValueError("max_length must be positive")

        if device is None:
            device = next(self.model.parameters()).device

        B, S = num_samples, max_length
        block_size = self.block_size
        num_blocks = math.ceil(S / block_size)
        mask_id = self.tokenizer.mask_token_id

        current_ids = torch.full((B, S), mask_id, dtype=torch.long, device=device)
        ts = torch.linspace(1.0 - self.t_eps, self.t_eps, num_steps, device=device)

        steps_per_block = max(1, num_steps // num_blocks)

        for step_i in tqdm.trange(num_steps, desc="Block-adjacent", disable=not show_progress):
            t_batch = ts[step_i].expand(B)

            leaf_logits, _ = self.model(input_ids=current_ids, t=t_batch)
            leaf_logits[..., mask_id] = float('-inf')  # never emit [MASK]

            # Leftover steps past the last full allotment stay on the last block.
            block_idx = min(step_i // steps_per_block, num_blocks - 1)
            start = block_idx * block_size
            end = min(start + block_size, S)

            block_tokens = leaf_logits[:, start:end].argmax(dim=-1)

            # Only fill positions that are still masked; decoded tokens are kept.
            block_ids = current_ids[:, start:end]
            current_ids[:, start:end] = torch.where(
                block_ids == mask_id, block_tokens, block_ids
            )

        is_masked = (current_ids == mask_id)
        if is_masked.any():
            leaf_logits, _ = self.model(
                input_ids=current_ids,
                t=torch.full((B,), self.t_eps, device=device),
            )
            leaf_logits[..., mask_id] = float('-inf')
            final_tokens = leaf_logits.argmax(dim=-1)
            current_ids = torch.where(is_masked, final_tokens, current_ids)

        stats = {
            "block_exit_levels": torch.zeros(B, num_blocks, dtype=torch.long),
            "token_exit_levels": torch.zeros(B, S, dtype=torch.long),
            "num_blocks": num_blocks,
            "block_size": block_size,
        }
        return current_ids.cpu(), stats
|
|