sad / src /eval /sampler.py
haochengsama's picture
Add files using upload-large-folder tool
922bb4b verified
Raw
History Blame Contribute Delete
22 kB
"""
Adaptive Skip / Backoff Sampler for SAD.
Three inference modes:
1. `AdjacentOnlySampler`: baseline – MDLM-style leaf-only decoding: masked
   positions are stochastically revealed with the model's argmax token
   (no hierarchy levels, no adaptive exit).
2. `SADSampler`: main method – token-wise adaptive exit over the whole
   sequence.
3. `SADBlockSampler`: block-autoregressive variant of `SADSampler` that runs
   the same adaptive exit one block at a time, with earlier blocks as causal
   context.
Confidence metrics supported (applied to the leaf distribution and to each
level's cluster distribution):
- "max_prob": max_k p[k]
- "neg_entropy": (H_max - H(p)) / H_max (normalized negative entropy)
State per position (SADSampler):
- resolved: bool [B, S] – True if token is finalized
- current_ids: [B, S] – current token ids (may be coarse cluster ids)
- exit_levels: [B, S] – which level each token exited at (for logging)
APPROXIMATION NOTE:
The samplers currently re-run the full model on all positions each step,
even for resolved tokens. Resolved token embeddings are still part of the
context (they serve as conditioning). This is correct but not optimal for
speed. A KV-cache variant that skips resolved positions is left as TODO.
"""
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm.auto as tqdm
def compute_confidence(
    p_leaf: torch.Tensor,
    metric: str = "neg_entropy",
) -> torch.Tensor:
    """
    Compute a per-position confidence score from a probability distribution.

    Args:
        p_leaf: [..., V] probabilities over the last dimension. The samplers
            pass [B, S, V] leaf distributions and [B, S, K_l] cluster
            distributions.
        metric: "max_prob" (largest probability) or "neg_entropy"
            (1 - H(p) / log V, i.e. normalized negative entropy).

    Returns:
        conf: [...] confidence scores in [0, 1] (last dimension reduced).
        For "neg_entropy" the computation and result are in at least float32.

    Raises:
        ValueError: if `metric` is not recognized.
    """
    if metric == "max_prob":
        return p_leaf.max(dim=-1).values  # [...]
    if metric == "neg_entropy":
        # Upcast half precision: in float16, `p + 1e-8` rounds to 0 when
        # p == 0, producing 0 * log(0) = NaN and a NaN confidence that never
        # passes any threshold. float64 inputs are kept as float64.
        p = p_leaf.to(torch.promote_types(p_leaf.dtype, torch.float32))
        V = p.shape[-1]
        eps = 1e-8
        H = -(p * (p + eps).log()).sum(dim=-1)  # [...]
        # log(V) is the entropy of the uniform distribution; the clamp guards
        # the degenerate V == 1 case (log 1 == 0).
        H_max = torch.tensor(float(V), device=p.device, dtype=p.dtype).log().clamp(min=1e-8)
        # Normalize: 0 = max entropy (least confident), 1 = zero entropy
        conf = 1.0 - H / H_max
        return conf.clamp(0.0, 1.0)  # [...]
    raise ValueError(f"Unknown confidence metric: {metric}")
class SADSampler(nn.Module):
    """
    Token-wise adaptive skip / backoff sampler.

    At each denoising step:
      1. Run model -> leaf logits + hidden state h.
      2. Compute per-level confidence from a softmax over cosine similarities
         between h and each level's prototypes.
      3. For each unresolved position, check levels finest-first (leaf, then
         level 1, 2, ...) and take the first level that is finer than the
         position's current level and whose confidence meets its threshold.
      4. Leaf hit (conf >= tau_0): finalize the token to argmax p_leaf.
      5. Cluster hit at level l: replace the token with the id of the nearest
         level-l prototype (no backtracking to coarser levels).

    Args:
        model: SADModel
        hierarchy: SoftAncestorHierarchy or HardAncestorHierarchy
        tokenizer: HuggingFace tokenizer (needs mask_token_id)
        confidence_metric: "neg_entropy" or "max_prob"
        thresholds: list of per-level thresholds tau_0..tau_{L-1};
            tau_0 is for leaf, tau_1 for level-1, ...
            Higher threshold = harder to move to that level.
        freeze_resolved: stored for API compatibility.
            NOTE(review): generate() does not read it; finalized tokens are
            always kept frozen in context.
    """

    def __init__(
        self,
        model: nn.Module,
        hierarchy,
        tokenizer,
        confidence_metric: str = "neg_entropy",
        thresholds: Optional[List[float]] = None,
        freeze_resolved: bool = True,
    ):
        super().__init__()
        self.model = model
        self.hierarchy = hierarchy
        self.tokenizer = tokenizer
        self.confidence_metric = confidence_metric
        self.freeze_resolved = freeze_resolved
        num_levels = hierarchy.num_levels
        if thresholds is None:
            # Default: leaf threshold 0.8, coarser levels 0.5
            thresholds = [0.8] + [0.5] * (num_levels - 1)
        assert len(thresholds) >= num_levels
        self.thresholds = thresholds

    @torch.no_grad()
    def generate(
        self,
        num_samples: int,
        num_steps: int,
        max_length: int = 512,
        device=None,
        show_progress: bool = True,
        random_coarse_init: bool = False,
    ) -> Tuple[torch.Tensor, Dict]:
        """
        Generate `num_samples` sequences of length `max_length`.

        Args:
            num_samples: B
            num_steps: maximum number of denoising iterations (stops early
                once every position is resolved)
            max_length: S
            device: target device (defaults to the model's device)
            show_progress: show a tqdm bar
            random_coarse_init: if True, initialize every position to a
                uniformly random cluster id of the coarsest level
                (num_levels - 1) instead of MASK (diverse coarse init mode).
                Ignored when the hierarchy has no cluster level
                (num_levels == 1).

        Returns:
            token_ids: [B, S] final generated leaf token ids (on CPU)
            stats: dict with
                'exit_levels' [B, S]: 0 for tokens finalized by confidence;
                    tokens force-decoded after the last step keep num_levels.
                'resolved_over_steps': fraction of resolved positions after
                    each executed step.
        """
        if device is None:
            device = next(self.model.parameters()).device
        B, S = num_samples, max_length
        mask_id = self.tokenizer.mask_token_id
        num_levels = self.hierarchy.num_levels
        level_sizes = self.hierarchy.level_sizes
        vocab_size = self.model.vocab_size

        # Extended id space: ids of level l start at offsets[l].
        #   offsets[0] = 0 (leaf ids)
        #   offsets[l] = level_sizes[0] + ... + level_sizes[l-1]
        # Assuming level_sizes[0] == vocab_size (TODO confirm), cluster ids of
        # level l are vocab_size + K_1 + ... + K_{l-1} + k.
        offsets = [0]
        for l in range(1, num_levels):
            offsets.append(offsets[-1] + level_sizes[l - 1])

        leaf_emb = self.model.get_leaf_embeddings()  # [V, d]
        mask_emb = leaf_emb[mask_id]  # [d]

        # Prototypes are fixed during sampling, so prepare them once:
        #  - proto_embs: cast to the embedding dtype/device, because the
        #    boolean index_put in ids_to_embeddings requires matching dtypes;
        #  - proto_norms: float32 L2-normalized copies for cosine similarity.
        proto_embs = [
            self.hierarchy.prototypes[l - 1].to(device=device, dtype=leaf_emb.dtype)
            for l in range(1, num_levels)
        ]  # proto_embs[l-1]: [K_l, d]
        proto_norms = [
            F.normalize(self.hierarchy.prototypes[l - 1].to(device).float(), dim=-1)
            for l in range(1, num_levels)
        ]  # proto_norms[l-1]: [K_l, d]

        def ids_to_embeddings(ids: torch.Tensor) -> torch.Tensor:
            """
            Convert [B, S] ids (leaf / cluster offset / mask) -> [B, S, d] embeddings.
              mask_id             -> mask_emb
              leaf id (0..V-1)    -> leaf_emb[id]   (excluding mask_id)
              cluster id at lvl l -> prototypes[l-1][id - offsets[l]]
            """
            embs = torch.zeros(B, S, leaf_emb.shape[-1],
                               device=device, dtype=leaf_emb.dtype)
            # Mask positions first so they are excluded from the leaf range.
            mask_pos = ids == mask_id
            if mask_pos.any():
                embs[mask_pos] = mask_emb
            leaf_pos = (ids < vocab_size) & ~mask_pos
            if leaf_pos.any():
                embs[leaf_pos] = leaf_emb[ids[leaf_pos]]
            for l in range(1, num_levels):
                lo = offsets[l]
                hi = lo + level_sizes[l]
                cluster_pos = (ids >= lo) & (ids < hi)
                if cluster_pos.any():
                    embs[cluster_pos] = proto_embs[l - 1][ids[cluster_pos] - lo]
            return embs

        # ---- Initial state ----
        # current_level tracks the coarseness of each token's current state:
        #   num_levels           = MASK
        #   1..num_levels-1      = intermediate cluster level l
        #   0                    = finalized leaf (resolved)
        current_ids = torch.full((B, S), mask_id, dtype=torch.long, device=device)
        current_level = torch.full((B, S), num_levels, dtype=torch.long, device=device)
        if random_coarse_init and num_levels > 1:
            top = num_levels - 1
            current_ids = torch.randint(
                offsets[top], offsets[top] + level_sizes[top], (B, S), device=device)
            current_level.fill_(top)
        resolved = torch.zeros(B, S, dtype=torch.bool, device=device)
        exit_levels = torch.full((B, S), num_levels, dtype=torch.long, device=device)
        resolved_over_steps = []

        # Only touch tqdm when a bar is actually requested.
        steps = range(num_steps)
        if show_progress:
            steps = tqdm.trange(num_steps, desc="sad")

        for step_i in steps:
            # ---- id -> embedding -> forward ----
            input_embs = ids_to_embeddings(current_ids)  # [B, S, d]
            leaf_logits, h = self.model(input_embeddings=input_embs)  # [B,S,V], [B,S,d]
            leaf_logits[..., mask_id] = float('-inf')  # never predict MASK
            p_leaf = leaf_logits.softmax(dim=-1)  # [B, S, V]

            # ---- Per-level confidence: softmax over cosine sims h . proto_l ----
            # NOTE(review): cosine sims lie in [-1, 1] and no temperature is
            # applied, so for large K_l these softmaxes are close to uniform
            # and cluster-level confidence stays low.
            h_norm = F.normalize(h.float(), dim=-1)  # [B, S, d]
            conf_per_level = [
                compute_confidence((h_norm @ pn.T).softmax(dim=-1), self.confidence_metric)
                for pn in proto_norms
            ]  # conf_per_level[l-1]: [B, S]
            conf_leaf = compute_confidence(p_leaf, self.confidence_metric)  # [B, S]

            # ---- Update: finest first, stop at first hit, no backtracking ----
            # A token at current_level L can only move to level l < L (finer).
            active = ~resolved
            new_ids = current_ids.clone()

            finalize = active & (current_level > 0) & (conf_leaf >= self.thresholds[0])
            if finalize.any():
                new_ids[finalize] = p_leaf.argmax(dim=-1)[finalize]
                current_level[finalize] = 0
                resolved[finalize] = True
                exit_levels[finalize] = 0
            updated = finalize.clone()

            # Intermediate levels from finest (l=1) to coarsest (l=num_levels-1).
            for l in range(1, num_levels):
                assign = (active & ~updated & (current_level > l)
                          & (conf_per_level[l - 1] >= self.thresholds[l]))
                if assign.any():
                    new_ids[assign] = (
                        (h_norm[assign] @ proto_norms[l - 1].T).argmax(dim=-1) + offsets[l]
                    )
                    current_level[assign] = l
                    updated |= assign

            current_ids = new_ids
            resolved_over_steps.append(resolved.float().mean().item())
            if resolved.all():
                if show_progress:
                    print(f"All tokens resolved at step {step_i + 1}")
                break

        # Final: force any unresolved position to its argmax leaf token.
        unresolved = ~resolved
        if unresolved.any():
            leaf_logits, _ = self.model(input_embeddings=ids_to_embeddings(current_ids))
            leaf_logits[..., mask_id] = float('-inf')
            current_ids[unresolved] = leaf_logits.argmax(dim=-1)[unresolved]

        stats = {
            "exit_levels": exit_levels.cpu(),
            "resolved_over_steps": resolved_over_steps,
        }
        return current_ids.cpu(), stats
class SADBlockSampler(nn.Module):
    """
    Block-by-block adaptive skip sampler for Block-AR inference.

    Generates one block at a time in left-to-right order. Within each block
    the same adaptive-skip logic as SADSampler applies (finest-first
    confidence check via cosine similarity between h and the level
    prototypes, no backtracking). Each denoising step runs
    model.forward_causal() on the full sequence. Earlier blocks hold their
    committed leaf ids, and later blocks are still MASK (presumably hidden by
    the causal mask; TODO confirm). Only the current block's outputs are used.

    Args:
        model: SADModel (needs block_size, vocab_size, get_leaf_embeddings()
            and forward_causal())
        hierarchy: SoftAncestorHierarchy or HardAncestorHierarchy
        tokenizer: HuggingFace tokenizer (needs mask_token_id)
        confidence_metric: "neg_entropy" or "max_prob"
        thresholds: per-level thresholds [tau_leaf, tau_l1, tau_l2, ...]
        num_steps_per_block: maximum denoising steps per block
    """

    def __init__(
        self,
        model: nn.Module,
        hierarchy,
        tokenizer,
        confidence_metric: str = "neg_entropy",
        thresholds: Optional[List[float]] = None,
        num_steps_per_block: int = 20,
    ):
        super().__init__()
        self.model = model
        self.hierarchy = hierarchy
        self.tokenizer = tokenizer
        self.confidence_metric = confidence_metric
        self.num_steps_per_block = num_steps_per_block
        num_levels = hierarchy.num_levels
        if thresholds is None:
            thresholds = [0.8] + [0.5] * (num_levels - 1)
        assert len(thresholds) >= num_levels
        self.thresholds = thresholds

    @torch.no_grad()
    def generate(
        self,
        num_samples: int,
        max_length: int = 512,
        device=None,
        show_progress: bool = True,
    ) -> Tuple[torch.Tensor, dict]:
        """
        Generate `num_samples` sequences of length `max_length`.

        Args:
            num_samples: B
            max_length: S; must be a multiple of model.block_size
            device: target device (defaults to the model's device)
            show_progress: show a tqdm bar over blocks

        Returns:
            token_ids: [B, S] leaf token ids (on CPU)
            stats: dict with 'exit_levels' [B, S]. Every token is written
                as 0 when it is finalized, whether by confidence or by the
                end-of-block forced decode, so all entries end up 0.
        """
        if device is None:
            device = next(self.model.parameters()).device
        block_size = self.model.block_size
        assert max_length % block_size == 0, \
            f"max_length ({max_length}) must be divisible by block_size ({block_size})"
        B = num_samples
        S = max_length
        num_blocks = S // block_size
        mask_id = self.tokenizer.mask_token_id
        num_levels = self.hierarchy.num_levels
        level_sizes = self.hierarchy.level_sizes
        vocab_size = self.model.vocab_size

        # Offsets for the extended id space (same layout as SADSampler):
        # offsets[l] = level_sizes[0] + ... + level_sizes[l-1] is the first id of level l.
        offsets = [0]
        for l in range(1, num_levels):
            offsets.append(offsets[-1] + level_sizes[l - 1])

        leaf_emb = self.model.get_leaf_embeddings()  # [V, d]
        mask_emb = leaf_emb[mask_id]  # [d]

        # Prototypes are fixed during sampling, so prepare them once instead
        # of on every step. Cast them to the embedding dtype because the
        # boolean index_put below requires matching dtypes. Normalize float32
        # copies for cosine similarity.
        proto_embs = [
            self.hierarchy.prototypes[l - 1].to(device=device, dtype=leaf_emb.dtype)
            for l in range(1, num_levels)
        ]  # proto_embs[l-1]: [K_l, d]
        proto_norms = [
            F.normalize(self.hierarchy.prototypes[l - 1].to(device).float(), dim=-1)
            for l in range(1, num_levels)
        ]  # proto_norms[l-1]: [K_l, d]

        def ids_to_embeddings(ids: torch.Tensor) -> torch.Tensor:
            """[B, S] ids (mask / leaf / cluster offset) -> [B, S, d] embeddings."""
            out = torch.zeros(B, ids.shape[1], leaf_emb.shape[-1],
                              device=device, dtype=leaf_emb.dtype)
            mask_pos = ids == mask_id
            if mask_pos.any():
                out[mask_pos] = mask_emb
            leaf_pos = (ids < vocab_size) & ~mask_pos
            if leaf_pos.any():
                out[leaf_pos] = leaf_emb[ids[leaf_pos]]
            for l in range(1, num_levels):
                lo = offsets[l]
                hi = lo + level_sizes[l]
                cluster_pos = (ids >= lo) & (ids < hi)
                if cluster_pos.any():
                    out[cluster_pos] = proto_embs[l - 1][ids[cluster_pos] - lo]
            return out

        # State for the full sequence
        current_ids = torch.full((B, S), mask_id, dtype=torch.long, device=device)
        exit_levels = torch.full((B, S), num_levels, dtype=torch.long, device=device)

        block_iter = range(num_blocks)
        if show_progress:
            # Imported lazily so tqdm is only needed when a bar is shown.
            from tqdm.auto import tqdm as _tqdm
            block_iter = _tqdm(block_iter, desc="sad_block")

        for blk in block_iter:
            blk_start = blk * block_size
            blk_end = blk_start + block_size
            # Basic slice -> view, so masked writes land in exit_levels.
            blk_exit = exit_levels[:, blk_start:blk_end]  # [B, bs]

            # Block-level state (levels as in SADSampler: num_levels = MASK, 0 = leaf)
            blk_current = current_ids[:, blk_start:blk_end].clone()  # [B, bs]
            blk_level = torch.full((B, block_size), num_levels, dtype=torch.long, device=device)
            blk_resolved = torch.zeros(B, block_size, dtype=torch.bool, device=device)

            for _ in range(self.num_steps_per_block):
                # Full-sequence input: committed prefix + current block state
                current_ids[:, blk_start:blk_end] = blk_current
                leaf_logits, h = self.model.forward_causal(
                    ids_to_embeddings(current_ids))  # [B, S, V], [B, S, d]

                # Focus on the current block only
                blk_logits = leaf_logits[:, blk_start:blk_end, :]  # [B, bs, V]
                blk_logits[..., mask_id] = float('-inf')  # never predict MASK
                p_leaf = blk_logits.softmax(dim=-1)  # [B, bs, V]

                # Per-level confidence: softmax over cosine sims h . proto_l
                h_norm = F.normalize(h[:, blk_start:blk_end, :].float(), dim=-1)  # [B, bs, d]
                conf_per_level = [
                    compute_confidence((h_norm @ pn.T).softmax(dim=-1), self.confidence_metric)
                    for pn in proto_norms
                ]  # conf_per_level[l-1]: [B, bs]
                conf_leaf = compute_confidence(p_leaf, self.confidence_metric)  # [B, bs]

                # Adaptive skip: finest first, no backtracking
                active = ~blk_resolved
                new_blk = blk_current.clone()

                finalize = active & (blk_level > 0) & (conf_leaf >= self.thresholds[0])
                if finalize.any():
                    new_blk[finalize] = p_leaf.argmax(dim=-1)[finalize]
                    blk_level[finalize] = 0
                    blk_resolved[finalize] = True
                    blk_exit[finalize] = 0
                updated = finalize.clone()

                for l in range(1, num_levels):
                    assign = (active & ~updated & (blk_level > l)
                              & (conf_per_level[l - 1] >= self.thresholds[l]))
                    if assign.any():
                        new_blk[assign] = (
                            (h_norm[assign] @ proto_norms[l - 1].T).argmax(dim=-1) + offsets[l]
                        )
                        blk_level[assign] = l
                        updated |= assign

                blk_current = new_blk
                if blk_resolved.all():
                    break

            # Force-resolve any remaining tokens in the block to their argmax leaf
            unresolved = ~blk_resolved
            if unresolved.any():
                current_ids[:, blk_start:blk_end] = blk_current
                leaf_logits, _ = self.model.forward_causal(ids_to_embeddings(current_ids))
                blk_logits = leaf_logits[:, blk_start:blk_end, :]
                blk_logits[..., mask_id] = float('-inf')
                blk_current[unresolved] = blk_logits.argmax(dim=-1)[unresolved]
                blk_exit[unresolved] = 0

            # Commit the resolved block (leaf ids only) as context for later blocks
            current_ids[:, blk_start:blk_end] = blk_current

        stats = {"exit_levels": exit_levels.cpu()}
        return current_ids.cpu(), stats
class AdjacentOnlySampler(nn.Module):
    """
    Baseline sampler: MDLM-style leaf-only decoding with no adaptive exit.

    Starting from an all-MASK sequence, every step runs the model, bans the
    MASK token, and reveals each still-masked position with probability
    1 / (steps remaining), writing the model's argmax token there. Revealed
    tokens are never changed afterwards. No hierarchy levels are involved.

    Args:
        model: SADModel (needs get_leaf_embeddings() and a forward that
            accepts input_embeddings=...)
        tokenizer: HuggingFace tokenizer (needs mask_token_id)
        t_eps: timestep epsilon; stored, but not read by generate()
    """

    def __init__(self, model: nn.Module, tokenizer, t_eps: float = 1e-4):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.t_eps = t_eps

    @torch.no_grad()
    def generate(
        self,
        num_samples: int,
        num_steps: int,
        max_length: int = 512,
        device=None,
        show_progress: bool = True,
    ) -> Tuple[torch.Tensor, Dict]:
        """
        Generate `num_samples` sequences of length `max_length`.

        Returns:
            token_ids: [B, S] leaf token ids (on CPU)
            stats: {'exit_levels': long zeros [B, S] on CPU,
                    'resolved_over_steps': []} (kept for API parity with
                    the adaptive samplers)
        """
        if device is None:
            device = next(self.model.parameters()).device
        batch, seq_len = num_samples, max_length
        mask_token = self.tokenizer.mask_token_id
        table = self.model.get_leaf_embeddings()  # [V, d]
        mask_vec = table[mask_token]  # [d]
        last_row = table.shape[0] - 1

        def embed(token_ids: torch.Tensor) -> torch.Tensor:
            # token_ids hold only leaf ids or mask_token
            looked_up = table[token_ids.clamp(0, last_row)]
            is_mask = (token_ids == mask_token).unsqueeze(-1)
            return torch.where(is_mask, mask_vec, looked_up)

        def predict(token_ids: torch.Tensor) -> torch.Tensor:
            # Argmax leaf token per position, with MASK banned.
            logits, _ = self.model(input_embeddings=embed(token_ids))
            logits[..., mask_token] = float('-inf')
            return logits.argmax(dim=-1)  # [B, S]

        tokens = torch.full((batch, seq_len), mask_token, dtype=torch.long, device=device)

        schedule = range(num_steps)
        if show_progress:
            schedule = tqdm.trange(num_steps, desc="Adjacent sampling")

        for step in schedule:
            proposal = predict(tokens)
            still_masked = tokens == mask_token
            # Reveal probability grows to 1 on the last step.
            reveal_p = 1.0 / max(num_steps - step, 1)
            coin = torch.rand(tokens.shape, dtype=torch.float32, device=tokens.device)
            tokens = torch.where(still_masked & (coin < reveal_p), proposal, tokens)

        # Any position still masked (e.g. num_steps == 0) gets one final decode.
        leftover = tokens == mask_token
        if leftover.any():
            tokens = torch.where(leftover, predict(tokens), tokens)

        stats = {
            "exit_levels": torch.zeros(batch, seq_len, dtype=torch.long),
            "resolved_over_steps": [],
        }
        return tokens.cpu(), stats