sad / src /eval /block_sampler.py
haochengsama's picture
Add files using upload-large-folder tool
922bb4b verified
Raw
History Blame Contribute Delete
14.6 kB
"""
Block-wise Sampler for SAD.
Instead of token-wise adaptive decoding, this sampler operates on blocks of tokens.
Given a context length of 512 and block size of 8, we have 64 blocks.
Block-wise adaptive: Resolve entire blocks at once based on aggregate confidence.
Per-level confidence is computed by projecting the leaf distribution upward
through the hierarchy (hierarchy.project_upward) and scoring each level's
distribution with compute_confidence from .sampler.
"""
from typing import Dict, List, Optional, Tuple
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm.auto as tqdm
from .sampler import compute_confidence
class BlockWiseAdaptiveSampler(nn.Module):
    """
    Block-wise adaptive skip/backoff sampler.

    Divides the sequence into blocks (e.g., 512 / 8 = 64 blocks).
    Each iteration:
    1. Run model forward -> leaf logits
    2. Project the leaf distribution upward through the hierarchy and compute
       a per-token confidence at every level
    3. Aggregate to block-level confidence
    4. Resolve entire blocks at the finest level (leaf first) whose
       threshold is met

    Args:
        model: SADModel
        hierarchy: SoftAncestorHierarchy or HardAncestorHierarchy
        tokenizer: HuggingFace tokenizer
        block_size: number of tokens per block (default 8)
        confidence_metric: "neg_entropy" or "max_prob"
        thresholds: per-level thresholds [tau_leaf, tau_l1, ...]
        block_agg: how to aggregate token confidences to block confidence:
            "mean", "min", or "max"
        freeze_resolved_blocks: if True, a block that has been resolved for a
            given sample is never rewritten by later steps; if False, a block
            may be re-resolved whenever it passes a threshold again
        t_eps: small offset keeping the time schedule inside (0, 1); the
            schedule runs from ``1 - t_eps`` down to ``t_eps``

    Raises:
        ValueError: if ``block_size`` is not positive or ``block_agg`` is not
            one of the supported aggregations.
    """

    def __init__(
        self,
        model: nn.Module,
        hierarchy,
        tokenizer,
        block_size: int = 8,
        confidence_metric: str = "neg_entropy",
        thresholds: Optional[List[float]] = None,
        block_agg: str = "mean",
        freeze_resolved_blocks: bool = True,
        t_eps: float = 1e-4,
    ):
        super().__init__()
        if block_size <= 0:
            raise ValueError("block_size must be positive")
        if block_agg not in ("mean", "min", "max"):
            raise ValueError(f"Unknown block_agg: {block_agg}")

        self.model = model
        self.hierarchy = hierarchy
        self.tokenizer = tokenizer
        self.block_size = block_size
        self.confidence_metric = confidence_metric
        self.block_agg = block_agg
        self.freeze_resolved_blocks = freeze_resolved_blocks
        # generate() builds its time schedule from t_eps, so it must exist.
        self.t_eps = t_eps

        if thresholds is None:
            num_levels = hierarchy.num_levels
            thresholds = [0.8] + [0.5] * (num_levels - 1)
        self.thresholds = thresholds

    def _aggregate_block_confidence(
        self, token_conf: torch.Tensor
    ) -> torch.Tensor:
        """
        Aggregate token-level confidences to block-level.

        When the sequence length is not a multiple of ``block_size`` the last
        block is shorter; only its real tokens contribute to the aggregate
        (padding never influences the result).

        Args:
            token_conf: [B, S] per-token confidence (floating point)
        Returns:
            block_conf: [B, num_blocks] per-block confidence
        Raises:
            ValueError: if ``self.block_agg`` is not "mean", "min" or "max".
        """
        agg = self.block_agg
        # Padding must be neutral for the reduction: +inf never wins a min,
        # -inf never wins a max. The mean is corrected separately below.
        if agg == "mean":
            pad_value = 0.0
        elif agg == "min":
            pad_value = float("inf")
        elif agg == "max":
            pad_value = float("-inf")
        else:
            raise ValueError(f"Unknown block_agg: {agg}")

        B, S = token_conf.shape
        block_size = self.block_size
        num_blocks = math.ceil(S / block_size)
        pad_len = num_blocks * block_size - S

        padded = token_conf
        if pad_len > 0:
            padded = F.pad(token_conf, (0, pad_len), value=pad_value)
        blocks = padded.reshape(B, num_blocks, block_size)

        if agg == "min":
            return blocks.min(dim=-1).values
        if agg == "max":
            return blocks.max(dim=-1).values

        block_conf = blocks.mean(dim=-1)
        if pad_len > 0:
            # Recompute the partial last block over its real tokens only.
            tail_start = (num_blocks - 1) * block_size
            block_conf[:, -1] = token_conf[:, tail_start:].mean(dim=-1)
        return block_conf

    def _get_block_resolution_level(
        self, block_conf: Dict[int, torch.Tensor]
    ) -> torch.Tensor:
        """
        Determine which level each block should resolve to.

        Levels are checked from leaf (0) to coarsest; a block takes the first
        level whose confidence reaches that level's threshold.

        Args:
            block_conf: dict mapping level to [B, num_blocks] confidence
        Returns:
            resolve_level: [B, num_blocks] int, 0=leaf, 1=level1, etc.,
                -1=unresolved
        """
        leaf_conf = block_conf[0]
        resolve_level = torch.full(
            leaf_conf.shape, -1, dtype=torch.long, device=leaf_conf.device
        )
        for level, tau in enumerate(self.thresholds):
            conf = block_conf[level]  # [B, num_blocks]
            # Only blocks not yet claimed by a finer level may take this one.
            should_resolve = (resolve_level == -1) & (conf >= tau)
            resolve_level[should_resolve] = level
        return resolve_level

    def _blocks_to_tokens(
        self, block_values: torch.Tensor, seq_len: int
    ) -> torch.Tensor:
        """
        Broadcast per-block values to per-token values.

        Args:
            block_values: [B, num_blocks] tensor of any dtype
            seq_len: number of tokens S to keep (the last block may be partial)
        Returns:
            [B, seq_len] tensor where each token carries its block's value
        """
        B, num_blocks = block_values.shape
        per_token = block_values.unsqueeze(-1).expand(
            B, num_blocks, self.block_size
        )
        return per_token.reshape(B, num_blocks * self.block_size)[:, :seq_len]

    @torch.no_grad()
    def generate(
        self,
        num_samples: int,
        num_steps: int,
        max_length: int = 512,
        device=None,
        show_progress: bool = True,
        random_coarse_init: bool = False,
    ) -> Tuple[torch.Tensor, Dict]:
        """
        Generate sequences using block-wise adaptive decoding.

        Blocks resolved at a non-leaf level keep their extended-vocabulary id
        (ancestor index plus the level offset) in the returned ids; only
        blocks that were never resolved are forced to leaf tokens at the end.

        Args:
            num_samples: batch size B
            num_steps: maximum number of decoding iterations
            max_length: sequence length S
            device: device to run on; defaults to the model's parameter device
            show_progress: show a progress bar and the early-exit message
            random_coarse_init: if True and the hierarchy has at least two
                levels, start from random ids of the last hierarchy level
                instead of all-mask
        Returns:
            token_ids: [B, S] final tokens (on CPU)
            stats: dict with block-level statistics
        """
        if device is None:
            device = next(self.model.parameters()).device

        B, S = num_samples, max_length
        block_size = self.block_size
        num_blocks = math.ceil(S / block_size)
        mask_id = self.tokenizer.mask_token_id

        # Initialize
        if random_coarse_init and self.hierarchy.num_levels >= 2:
            K_top = self.hierarchy.level_sizes[-1]
            offset = sum(self.hierarchy.level_sizes[:-1])
            current_ids = torch.randint(0, K_top, (B, S), device=device) + offset
        else:
            current_ids = torch.full(
                (B, S), mask_id, dtype=torch.long, device=device
            )

        # Track block resolution
        block_resolved = torch.zeros(
            B, num_blocks, dtype=torch.bool, device=device
        )
        block_exit_levels = torch.full(
            (B, num_blocks), -1, dtype=torch.long, device=device
        )
        token_exit_levels = torch.full(
            (B, S), self.hierarchy.num_levels - 1, dtype=torch.long, device=device
        )

        ts = torch.linspace(1.0 - self.t_eps, self.t_eps, num_steps, device=device)

        # Check if hierarchy is soft (needs embeddings)
        is_soft = hasattr(self.hierarchy, 'prototypes') and \
            any(p.requires_grad for p in self.hierarchy.parameters())

        resolved_over_steps = []
        for step_i in tqdm.trange(
            num_steps, desc="Block-wise SAD", disable=not show_progress
        ):
            t_batch = ts[step_i].expand(B)

            # Forward pass
            leaf_logits, _ = self.model(input_ids=current_ids, t=t_batch)
            leaf_logits[..., mask_id] = float('-inf')  # never predict the mask
            p_leaf = leaf_logits.softmax(dim=-1)  # [B, S, V]

            # Project upward
            if is_soft:
                leaf_emb = self.model.get_leaf_embeddings()
                assignments = self.hierarchy.get_all_assignments(leaf_emb)
            else:
                assignments = self.hierarchy.get_all_assignments()
            # p_levels[l-1] = p^(l): [B, S, K_l]
            p_levels = self.hierarchy.project_upward(
                p_leaf, assignments=assignments
            )

            # Per-token confidence at each level, aggregated per block
            block_conf = {
                0: self._aggregate_block_confidence(
                    compute_confidence(p_leaf, self.confidence_metric)
                ),
            }
            for li, p_l in enumerate(p_levels, start=1):
                block_conf[li] = self._aggregate_block_confidence(
                    compute_confidence(p_l, self.confidence_metric)
                )

            # Determine resolution level for each block
            resolve_level = self._get_block_resolution_level(block_conf)

            # Blocks to (re)write this step, decided per sample so that the
            # outcome for one sample never depends on the rest of the batch.
            update = resolve_level >= 0
            if self.freeze_resolved_blocks:
                update = update & ~block_resolved

            if update.any():
                token_update = self._blocks_to_tokens(update, S)      # [B, S]
                token_level = self._blocks_to_tokens(resolve_level, S)  # [B, S]
                for lvl in range(len(self.thresholds)):
                    selected = token_update & (token_level == lvl)
                    if not selected.any():
                        continue
                    if lvl == 0:
                        # Leaf level: greedily select tokens
                        chosen = p_leaf.argmax(dim=-1)
                    else:
                        # Intermediate level: greedy ancestor, shifted into
                        # the extended vocabulary
                        level_offset = sum(self.hierarchy.level_sizes[:lvl])
                        chosen = p_levels[lvl - 1].argmax(dim=-1) + level_offset
                    current_ids = torch.where(selected, chosen, current_ids)
                    token_exit_levels[selected] = lvl

                block_resolved |= update
                block_exit_levels = torch.where(
                    update, resolve_level, block_exit_levels
                )

            resolved_over_steps.append(block_resolved.float().mean().item())
            if block_resolved.all():
                if show_progress:
                    print(f"All blocks resolved at step {step_i + 1}")
                break

        # Final pass: force unresolved to leaf
        unresolved_blocks = ~block_resolved
        if unresolved_blocks.any():
            leaf_logits, _ = self.model(
                input_ids=current_ids,
                t=torch.full((B,), self.t_eps, device=device),
            )
            leaf_logits[..., mask_id] = float('-inf')
            final_tokens = leaf_logits.argmax(dim=-1)

            unresolved_tokens = self._blocks_to_tokens(unresolved_blocks, S)
            current_ids = torch.where(unresolved_tokens, final_tokens, current_ids)
            token_exit_levels[unresolved_tokens] = 0
            block_exit_levels[unresolved_blocks] = 0

        stats = {
            "block_exit_levels": block_exit_levels.cpu(),
            "token_exit_levels": token_exit_levels.cpu(),
            "block_resolved_over_steps": resolved_over_steps,
            "num_blocks": num_blocks,
            "block_size": block_size,
        }
        return current_ids.cpu(), stats
class BlockWiseAdjacentSampler(nn.Module):
    """
    Baseline block-wise sampler without adaptive skipping.

    Decodes block by block in a fixed left-to-right order, greedily taking
    the argmax of the leaf logits for each still-masked position.

    Args:
        model: module called as ``model(input_ids=..., t=...)`` and returning
            a pair whose first element is the leaf logits
        tokenizer: tokenizer providing ``mask_token_id``
        block_size: number of tokens per block (default 8)
        t_eps: small offset keeping the time schedule inside (0, 1); the
            schedule runs from ``1 - t_eps`` down to ``t_eps``

    Raises:
        ValueError: if ``block_size`` is not positive.
    """

    def __init__(
        self,
        model: nn.Module,
        tokenizer,
        block_size: int = 8,
        t_eps: float = 1e-4,
    ):
        super().__init__()
        # Reject a bad block size here rather than failing with a
        # ZeroDivisionError deep inside generate().
        if block_size <= 0:
            raise ValueError("block_size must be positive")
        self.model = model
        self.tokenizer = tokenizer
        self.block_size = block_size
        self.t_eps = t_eps

    @torch.no_grad()
    def generate(
        self,
        num_samples: int,
        num_steps: int,
        max_length: int = 512,
        device=None,
        show_progress: bool = True,
    ) -> Tuple[torch.Tensor, Dict]:
        """
        Simple block-wise decoding without adaptive skip.

        Each step is assigned to one block (``num_steps // num_blocks`` steps
        per block, at least one). Positions still masked after the scheduled
        steps are filled by one final forward pass.

        Args:
            num_samples: batch size B
            num_steps: number of scheduled decoding steps
            max_length: sequence length S
            device: device to run on; defaults to the model's parameter device
            show_progress: show a progress bar
        Returns:
            token_ids: [B, S] final tokens (on CPU)
            stats: dict with all-zero exit levels plus block geometry
        """
        if device is None:
            device = next(self.model.parameters()).device

        B, S = num_samples, max_length
        block_size = self.block_size
        num_blocks = math.ceil(S / block_size)
        mask_id = self.tokenizer.mask_token_id

        current_ids = torch.full((B, S), mask_id, dtype=torch.long, device=device)
        ts = torch.linspace(1.0 - self.t_eps, self.t_eps, num_steps, device=device)

        # Assign steps per block (guarded so an empty sequence cannot divide
        # by zero).
        steps_per_block = max(1, num_steps // max(1, num_blocks))

        for step_i in tqdm.trange(
            num_steps, desc="Block-adjacent", disable=not show_progress
        ):
            # Determine which block to update
            block_idx = min(step_i // steps_per_block, num_blocks - 1)
            start = block_idx * block_size
            end = min(start + block_size, S)

            # Only masked positions are ever written, so when the block has
            # none left the forward pass would be a no-op: skip it.
            is_masked = (current_ids[:, start:end] == mask_id)
            if not is_masked.any():
                continue

            t_batch = ts[step_i].expand(B)
            leaf_logits, _ = self.model(input_ids=current_ids, t=t_batch)
            leaf_logits[..., mask_id] = float('-inf')  # never predict the mask

            block_tokens = leaf_logits[:, start:end].argmax(dim=-1)  # [B, block]
            current_ids[:, start:end] = torch.where(
                is_masked, block_tokens, current_ids[:, start:end]
            )

        # Final fill
        is_masked = (current_ids == mask_id)
        if is_masked.any():
            leaf_logits, _ = self.model(
                input_ids=current_ids,
                t=torch.full((B,), self.t_eps, device=device),
            )
            leaf_logits[..., mask_id] = float('-inf')
            final_tokens = leaf_logits.argmax(dim=-1)
            current_ids = torch.where(is_masked, final_tokens, current_ids)

        stats = {
            "block_exit_levels": torch.zeros(B, num_blocks, dtype=torch.long),
            "token_exit_levels": torch.zeros(B, S, dtype=torch.long),
            "num_blocks": num_blocks,
            "block_size": block_size,
        }
        return current_ids.cpu(), stats