sad / scripts /inference_sad.py
haochengsama's picture
Add files using upload-large-folder tool
8b0aeb2 verified
Raw
History Blame Contribute Delete
39.9 kB
#!/usr/bin/env python3
"""
inference_sad.py – Block-wise hierarchical diffusion sampling from a trained
SADModel.
Generation proceeds block by block left-to-right. Within each block, a small
random subset of non-leaf positions is advanced each round to some strictly
finer level in the hierarchy
mask (level K+1) > ancestors (K, …, 1) > leaf (level 0)
A transition may jump any number of levels (e.g. mask → leaf directly, or
ancestor l → ancestor l' with l' < l, or ancestor → leaf) as long as the new
level is strictly finer than the current one — never stay, never revert.
Rounds repeat until every position in the block is leaf; then the next block
begins.
Each denoising round:
1. One forward pass on the current block (K/V cache holds earlier blocks).
2. Softmax the leaf logits and project through the fixed LUT
(`AncestorTable.projection_matrix`) into every strictly-finer ancestor
level; max over each distribution gives per-level confidence (used to
rank candidate levels). For ancestor levels the conf is multiplied by
a per-level scalar λ_l ∈ [0, 1] before the cross-level comparison
(smaller λ_l biases the schedule away from that ancestor level —
λ_l = 0 disables it; the default λ = 1 reproduces the original
behavior). Leaf (l=0) is never scaled. The target id is then produced
per-level:
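- leaf level (l=0): multinomial sampling from the temperature-scaled leaf
  distribution softmax(logits / leaf_temperature) (stochastic)
- ancestor level (l≥1): multinomial sampling from the cluster dist. (stochastic)
Ancestor confidences and samples always use the original (temperature=1)
softmax; the leaf confidence and leaf sample use the temperature-scaled
softmax, so with the default leaf_temperature = 1 every level is scored
from the same distribution.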
3. Randomly pick `positions_per_step` non-leaf positions per sample and
transition each to its best strictly-finer level.
Finalized blocks' K/V are cached so forwards only recompute the current block.
Usage:
python scripts/inference_sad.py \\
--config configs/sad_owt.yaml \\
--checkpoint outputs/sad/latest.pt \\
--num_samples 4
"""
from __future__ import annotations
import argparse
import sys
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1] # sad/
from typing import Optional
import torch
import torch.nn.functional as F
import yaml
sys.path.insert(0, str(ROOT))
from src.models.sad_model import SADModel
from src.models.dit_components import apply_rotary_pos_emb, modulate_fused
from src.diffusion.ancestor_table import AncestorTable
from src.data import build_owt_dataloader
from einops import rearrange
# ─────────────────────────────────────────────────────────────────────────────
# Sampler
# ─────────────────────────────────────────────────────────────────────────────
class BlockDiffusionSampler:
    """
    Block-wise hierarchical diffusion sampler for SADModel.

    State per position is (level, value):
        level = 0      → leaf token; value = token id
        level ∈ [1, K] → ancestor at level l; value = cluster id in K_l
        level = K + 1  → mask

    Per-block denoising loop (random position selection, strict-descent
    schedule). Until every position in the block is leaf:
      1. Forward pass on the current block (cache holds earlier blocks).
      2. Vectorized over all block positions, score every candidate level:
             leaf target (l=0):     p = softmax(logits / leaf_temperature)
             ancestor target (l≥1): p = softmax(logits) @ W_l   [V, K_l]
         Each candidate contributes (conf, id): conf = max(p), multiplied by
         λ_l for ancestor levels, and is used only to compare levels; id is
         a multinomial draw from p (for leaf and ancestor levels alike).
         Only levels strictly finer than the position's current level are
         eligible, so mask → leaf (skipping every ancestor) and any other
         multi-level jump are legal. The eligible level with the highest
         confidence wins.
      3. Randomly pick `positions_per_step` non-leaf positions per sample
         and apply the selected transition at those positions only.
    """

    def __init__(
        self,
        model: SADModel,
        ancestor_table: AncestorTable,
        tokenizer,
        device: torch.device,
        dtype: torch.dtype = torch.bfloat16,
        level_lambdas: Optional[list] = None,
        leaf_temperature: float = 1.0,
    ):
        """
        Args:
            model: trained SADModel whose blocks / embeddings are reused here.
            ancestor_table: AncestorTable providing per-level ancestor
                embeddings and LUT projection matrices W_l.
            tokenizer: must expose `mask_token_id`; `eos_token_id` (if any)
                is used for early stopping in `generate`.
            device, dtype: device and dtype for embeddings and the forward.
            level_lambdas: length-K list of floats in [0, 1]. λ_l (ancestor
                level l = 1..K) multiplies that level's max-prob conf before
                the cross-level comparison that picks the winning target.
                Leaf (l=0) is never scaled. None → all ones.
            leaf_temperature: must be > 0. Leaf logits are divided by it
                before the softmax that yields the leaf confidence and the
                leaf multinomial sample. Values < 1.0 sharpen it. Ancestor
                confidences and samples always use the temperature-1
                softmax. Default 1.0 (no scaling).

        Raises:
            ValueError: if leaf_temperature is not strictly positive.
        """
        # `not x > 0` also rejects NaN. A non-positive temperature would make
        # the softmax produce NaNs that only fail later inside multinomial.
        if not float(leaf_temperature) > 0.0:
            raise ValueError(
                f"leaf_temperature must be > 0, got {leaf_temperature}"
            )
        self.model = model
        self.ancestor_table = ancestor_table
        self.tokenizer = tokenizer
        self.device = device
        self.dtype = dtype
        self.leaf_temperature = float(leaf_temperature)

        self.block_size: int = model.block_size
        self.max_seq_len: int = model.max_seq_len
        self.vocab_size: int = model.vocab_size
        self.mask_id: int = tokenizer.mask_token_id
        assert self.mask_id is not None, "tokenizer must have mask_token_id"

        self.K: int = ancestor_table.num_levels  # number of ancestor levels
        self.mask_level: int = self.K + 1

        if level_lambdas is None:
            level_lambdas = [1.0] * self.K
        assert len(level_lambdas) == self.K, (
            f"level_lambdas must have length K={self.K}, got {len(level_lambdas)}"
        )
        for x in level_lambdas:
            assert 0.0 <= float(x) <= 1.0, f"each λ must be in [0, 1], got {x}"
        # 1-indexed: self.level_lambdas[l] is λ_l for ancestor level l ∈ [1, K].
        self.level_lambdas = [None] + [float(x) for x in level_lambdas]

        # Leaf embedding table (read-only view, detached).
        self.leaf_emb = model.get_leaf_embeddings().to(device=device, dtype=dtype).detach()
        self.mask_emb = self.leaf_emb[self.mask_id]  # [d]

        # Ancestor embeddings per level: fed into the model, so keep them in
        # self.dtype to match the model weights.
        self.anc_embs = [None] + [
            ancestor_table.ancestor_embeddings(l).to(device=device, dtype=dtype).detach()
            for l in range(1, self.K + 1)
        ]
        # LUT projection matrices W_l: used only on the scoring side (fp32).
        self.W = [None] + [
            ancestor_table.projection_matrix(l).to(device=device, dtype=torch.float32).detach()
            for l in range(1, self.K + 1)
        ]

    # ───────────────────────────────────────────────────────────────────────
    def _build_mixed_embeddings(
        self, level_ids: torch.Tensor, value_ids: torch.Tensor,
    ) -> torch.Tensor:
        """
        Build [B, S, d] input embeddings from per-position (level, value).

        Level 0 uses leaf_emb[value], level K+1 uses the mask embedding, and
        level l ∈ [1, K] uses anc_embs[l][value]. Intended to mirror
        NoisyStateBuilder.build_noisy_embeddings from training.
        """
        B, S = level_ids.shape
        d = self.leaf_emb.shape[-1]
        embs = torch.empty(B, S, d, device=self.device, dtype=self.dtype)

        # Leaf (level 0): leaf_emb[value].
        m0 = level_ids == 0
        if m0.any():
            embs[m0] = self.leaf_emb[value_ids[m0]]

        # Mask (level K+1): leaf_emb[mask_id].
        mM = level_ids == self.mask_level
        if mM.any():
            embs[mM] = self.mask_emb

        # Ancestor levels 1..K: anc_embs[l][value].
        for l in range(1, self.K + 1):
            ml = level_ids == l
            if ml.any():
                embs[ml] = self.anc_embs[l][value_ids[ml]]
        return embs

    # ───────────────────────────────────────────────────────────────────────
    # KV-cache–aware forward. Under the block-causal mask, the K/V produced at
    # positions in finalized (leaf) earlier blocks never change, so they are
    # computed once per block and reused across all denoising rounds of the
    # current block.
    #
    # This method inlines the DiT block forward so it can (a) accept a K/V
    # prefix cache and (b) avoid recomputing Q/K/V for earlier blocks. With
    # k_prefix=None it is an uncached single-block pass.
    # ───────────────────────────────────────────────────────────────────────
    def _run_layer_cached(
        self,
        layer_idx: int,
        x: torch.Tensor,
        rotary_cos_sin,
        c: torch.Tensor,
        k_prefix: Optional[torch.Tensor] = None,
        v_prefix: Optional[torch.Tensor] = None,
    ):
        """
        Run one DiT block on `x` (current block positions only) with an
        optional cached K/V prefix.

        Args:
            layer_idx: index into self.model.blocks
            x: [B, bs, d] current block hidden state
            rotary_cos_sin: rotary (cos, sin) for positions block_start..block_end-1
            c: [B, cond_dim] conditioning
            k_prefix, v_prefix: [B, H, S_prefix, d_head] post-rotary cached
                K/V from earlier blocks. None means no prefix.

        Returns:
            x_out: [B, bs, d]
            k_new: [B, H, bs, d_head] post-rotary K for the current block
            v_new: [B, H, bs, d_head] V for the current block
        """
        layer = self.model.blocks[layer_idx]
        B = x.shape[0]
        H = layer.n_heads
        dropout = layer.dropout
        bds_fn = layer._bias_dropout_scale_fn()

        (shift_msa, scale_msa, gate_msa,
         shift_mlp, scale_mlp, gate_mlp) = layer.adaLN_modulation(c)[:, None].chunk(6, dim=2)

        x_skip = x
        x_normed = modulate_fused(layer.norm1(x), shift_msa, scale_msa)
        qkv = layer.attn_qkv(x_normed)
        qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=H)
        cos, sin = rotary_cos_sin
        qkv = apply_rotary_pos_emb(qkv, cos.to(qkv.dtype), sin.to(qkv.dtype))

        q = qkv[:, :, 0].transpose(1, 2)      # [B, H, bs, d_h]
        k_new = qkv[:, :, 1].transpose(1, 2)  # [B, H, bs, d_h]
        v_new = qkv[:, :, 2].transpose(1, 2)

        if k_prefix is not None:
            k = torch.cat([k_prefix, k_new], dim=2)
            v = torch.cat([v_prefix, v_new], dim=2)
        else:
            k = k_new
            v = v_new

        # No mask: the current block attends to the whole prefix (block-causal
        # lookback) and to itself (bidirectional within the block).
        attn_out = F.scaled_dot_product_attention(q, k, v)
        attn_out = rearrange(attn_out, "b h s d -> b s (h d)", b=B)

        x = bds_fn(layer.attn_out(attn_out), None, gate_msa, x_skip, dropout)
        x = bds_fn(
            layer.mlp(modulate_fused(layer.norm2(x), shift_mlp, scale_mlp)),
            None, gate_mlp, x, dropout,
        )
        return x, k_new, v_new

    def _forward_block_cached(
        self,
        level_ids_cur: torch.Tensor,
        value_ids_cur: torch.Tensor,
        block_idx: int,
        kv_cache: list,
        is_clean: bool = False,
    ):
        """
        Forward pass over a single block using cached prefix K/V.

        Args:
            level_ids_cur, value_ids_cur: [B, bs] current block state
            block_idx: absolute block index (for position / rotary embeddings)
            kv_cache: list[(k_prefix, v_prefix) or (None, None)] per layer
            is_clean: if True, use segment_embed(1) (clean half) to match
                training's clean context. Used when capturing K/V for
                finalized blocks and prompt warm-up.

        Returns:
            logits_cur: [B, bs, V] (mask column set to -inf)
            new_kv: list[(k_cur, v_cur)] per layer; the caller appends to cache
        """
        model = self.model
        B, bs = level_ids_cur.shape
        block_start = block_idx * self.block_size
        block_end = block_start + bs
        device = self.device

        embs = self._build_mixed_embeddings(level_ids_cur, value_ids_cur)  # [B, bs, d]
        # Input projection (weights are self.dtype; embs already self.dtype).
        x = model.input_proj(embs)

        # Position embeddings for this block only. Using arange(bs) (equal to
        # block_size for every block generate() builds) keeps all three
        # embedding terms the same length even for a shorter block.
        block_idx_t = torch.full((bs,), block_idx, dtype=torch.long, device=device)
        intra_pos = torch.arange(bs, device=device)
        # segment=0 for noisy (denoising rounds), segment=1 for clean (cache capture).
        seg_val = 1 if is_clean else 0
        seg_id = torch.full((bs,), seg_val, dtype=torch.long, device=device)
        pos_emb = (
            model.block_idx_embed(block_idx_t)
            + model.intra_pos_embed(intra_pos)
            + model.segment_embed(seg_id)
        ).unsqueeze(0).to(x.dtype)
        x = x + pos_emb

        c = model.cond_bias.unsqueeze(0).expand(B, -1).to(x.dtype)

        # Rotary for the absolute positions of this block.
        position_ids = torch.arange(block_start, block_end, device=device)
        rotary_cos_sin = model.rotary_emb(x, position_ids=position_ids)

        new_kv = []
        autocast_device = "cuda" if device.type == "cuda" else "cpu"
        with torch.autocast(device_type=autocast_device, dtype=self.dtype):
            for layer_idx in range(len(model.blocks)):
                k_prefix, v_prefix = kv_cache[layer_idx]
                x, k_cur, v_cur = self._run_layer_cached(
                    layer_idx, x, rotary_cos_sin, c,
                    k_prefix=k_prefix, v_prefix=v_prefix,
                )
                new_kv.append((k_cur, v_cur))
            logits = model.output_layer(x, c)  # [B, bs, rounded_leaf]
        logits = logits[..., :self.vocab_size]
        logits[..., self.mask_id] = float("-inf")
        return logits, new_kv

    @staticmethod
    def _append_kv(kv_cache: list, new_kv: list) -> list:
        """Append per-layer new_kv to kv_cache along the sequence dim (dim 2)."""
        out = []
        for (kp, vp), (kn, vn) in zip(kv_cache, new_kv):
            if kp is None:
                out.append((kn, vn))
            else:
                out.append((torch.cat([kp, kn], dim=2),
                            torch.cat([vp, vn], dim=2)))
        return out

    @staticmethod
    def _fill_unresolved(
        level_ids: torch.Tensor, value_ids: torch.Tensor, fill_id: int,
    ) -> None:
        """Turn every non-leaf position into a leaf holding `fill_id` (in place)."""
        unresolved = level_ids != 0
        value_ids[unresolved] = fill_id
        level_ids[unresolved] = 0

    def _best_targets(
        self, block_logits: torch.Tensor, cur_level_block: torch.Tensor,
    ):
        """
        Pick the best strictly-finer (level, id) target for every position.

        Args:
            block_logits: [B, bs, V] leaf logits (mask column already -inf).
            cur_level_block: [B, bs] current level of each block position.

        Returns:
            best_level: [B, bs] long. The winning level, or -1 where the
                position is already leaf (nothing is eligible).
            best_id: [B, bs] long. Sampled token / cluster id for best_level
                (0 where best_level is -1).

        The RNG draws happen in a fixed order: leaf first, then ancestor
        levels 1..K.
        """
        B, bs = cur_level_block.shape
        device = block_logits.device

        leaf_logits_fp = block_logits.float()
        # Temperature-1 probs: used for the ancestor projection (conf + sample).
        leaf_prob_raw = F.softmax(leaf_logits_fp, dim=-1)  # [B, bs, V]
        # Temperature-scaled probs: used for the leaf conf and leaf sample.
        if self.leaf_temperature != 1.0:
            leaf_prob_temp = F.softmax(leaf_logits_fp / self.leaf_temperature, dim=-1)
        else:
            leaf_prob_temp = leaf_prob_raw

        best_conf = torch.full((B, bs), float("-inf"), device=device, dtype=torch.float32)
        best_level = torch.full((B, bs), -1, device=device, dtype=torch.long)
        best_id = torch.zeros((B, bs), device=device, dtype=torch.long)

        # Leaf target (l = 0): eligible for every non-leaf position.
        leaf_conf = leaf_prob_temp.max(dim=-1).values  # [B, bs]
        leaf_id = torch.multinomial(
            leaf_prob_temp.reshape(-1, leaf_prob_temp.shape[-1]),
            num_samples=1,
        ).squeeze(-1).reshape(B, bs)
        upd = (cur_level_block > 0) & (leaf_conf > best_conf)
        best_conf = torch.where(upd, leaf_conf, best_conf)
        best_level = torch.where(upd, torch.zeros_like(best_level), best_level)
        best_id = torch.where(upd, leaf_id, best_id)

        # Ancestor targets l = 1..K: conf = λ_l * max raw cluster prob,
        # id drawn from the raw cluster distribution.
        for l in range(1, self.K + 1):
            V_anc = self.W[l].shape[0]
            cluster_prob_raw = leaf_prob_raw[..., :V_anc] @ self.W[l]  # [B, bs, K_l]
            conf_l = cluster_prob_raw.max(dim=-1).values * self.level_lambdas[l]
            id_l = torch.multinomial(
                cluster_prob_raw.reshape(-1, cluster_prob_raw.shape[-1]),
                num_samples=1,
            ).squeeze(-1).reshape(B, bs)
            # Level l is only eligible if it is strictly finer than the current one.
            upd = (cur_level_block > l) & (conf_l > best_conf)
            best_conf = torch.where(upd, conf_l, best_conf)
            best_level = torch.where(upd, torch.full_like(best_level, l), best_level)
            best_id = torch.where(upd, id_l, best_id)

        return best_level, best_id

    # ───────────────────────────────────────────────────────────────────────
    @torch.no_grad()
    def generate(
        self,
        batch_size: Optional[int] = None,
        prompt_ids: Optional[torch.Tensor] = None,
        positions_per_step: int = 1,
        return_intermediate: bool = False,
        stop_on_eos: bool = True,
    ) -> dict:
        """
        Block-by-block generation with KV cache and random per-round position
        selection.

        Within each block, rounds repeat until every position is leaf. Each
        round runs one forward pass and computes the best strictly-finer
        target (level, id) for every non-leaf position (see `_best_targets`).
        It then picks `positions_per_step` random non-leaf positions per
        sample and applies their transitions.

        Unconditional: pass `batch_size` (and leave `prompt_ids=None`); starts
        from an all-mask sequence of length `self.max_seq_len`.
        Conditional: pass `prompt_ids` with shape [B, P] where P is a multiple
        of `block_size`; the first P positions are fixed as leaf tokens, the
        remaining positions are generated block by block.

        If `stop_on_eos` is set and the tokenizer has an eos_token_id,
        generation stops after the first block at which every sample has
        produced an EOS in some generated block.

        Returns:
            dict with
              "tokens": [B, max_seq_len] long tensor on CPU holding leaf token
                  ids. After an early EOS stop, the positions of blocks that
                  were never generated hold eos_token_id.
              "prompt_len": P (0 when unconditional).
              "num_steps": total denoising rounds across generated blocks.
              "intermediate": only when return_intermediate; list of CPU
                  (level_ids, value_ids) snapshots, one per round.

        Raises:
            ValueError: if positions_per_step < 1. Otherwise no position
                could advance, and each block would only be resolved by the
                fallback after exhausting the round cap.
        """
        if positions_per_step < 1:
            raise ValueError(
                f"positions_per_step must be >= 1, got {positions_per_step}"
            )
        block_size = self.block_size
        device = self.device
        total_len = self.max_seq_len
        assert total_len % block_size == 0, (
            f"max_seq_len ({total_len}) must be divisible by block_size "
            f"({block_size})"
        )

        if prompt_ids is not None:
            prompt_ids = prompt_ids.to(device=device, dtype=torch.long)
            B, P = prompt_ids.shape
            assert P % block_size == 0, (
                f"prompt length P={P} must be a multiple of block_size={block_size}"
            )
            assert P < total_len, (
                f"prompt length P={P} must be < total_len={total_len}"
            )
            start_block = P // block_size
        else:
            assert batch_size is not None, (
                "Either batch_size (unconditional) or prompt_ids (conditional) "
                "must be provided."
            )
            B = batch_size
            P = 0
            start_block = 0

        # ── Initialize state: every position is mask; prompt positions are leaf.
        level_ids = torch.full(
            (B, total_len), self.mask_level, dtype=torch.long, device=device,
        )
        value_ids = torch.zeros((B, total_len), dtype=torch.long, device=device)
        if P > 0:
            level_ids[:, :P] = 0
            value_ids[:, :P] = prompt_ids

        num_blocks = total_len // block_size
        intermediate = [] if return_intermediate else None
        finished = torch.zeros(B, dtype=torch.bool, device=device)
        eos_id = getattr(self.tokenizer, "eos_token_id", None)

        # ── KV cache: per-layer (k_prefix, v_prefix) for finalized blocks.
        # Block b's K/V is appended once b is fully resolved, so when block
        # b+1 starts the cache covers blocks 0..b.
        num_layers = len(self.model.blocks)
        kv_cache = [(None, None) for _ in range(num_layers)]

        # ── Warm up the KV cache over prompt blocks (all leaf). is_clean=True:
        # prompt blocks act as clean context (segment=1) for later blocks.
        for b in range(start_block):
            bs0 = b * block_size
            be0 = bs0 + block_size
            _, new_kv = self._forward_block_cached(
                level_ids[:, bs0:be0], value_ids[:, bs0:be0], b, kv_cache,
                is_clean=True,
            )
            kv_cache = self._append_kv(kv_cache, new_kv)

        # ── Block loop (skips prompt blocks).
        # Each round advances at least one non-leaf position by ≥1 level
        # (strict descent). The worst case needs K+1 transitions per position,
        # so block_size * (K+1) rounds always suffice.
        rounds_cap_per_block = block_size * (self.K + 1)
        total_steps = 0  # total denoising rounds across all generated blocks

        for b in range(start_block, num_blocks):
            block_start = b * block_size
            block_end = block_start + block_size

            for _ in range(rounds_cap_per_block):
                cur_level_block = level_ids[:, block_start:block_end]  # [B, bs]
                non_leaf_block = cur_level_block > 0                   # [B, bs]
                if not non_leaf_block.any():
                    break

                # 1) Forward pass on the current block (cache holds blocks 0..b-1).
                block_logits, _ = self._forward_block_cached(
                    level_ids[:, block_start:block_end],
                    value_ids[:, block_start:block_end],
                    b, kv_cache,
                )  # [B, bs, V]

                # 2) Best strictly-finer target for every block position.
                best_level, best_id = self._best_targets(block_logits, cur_level_block)

                # 3) Randomly pick `positions_per_step` non-leaf positions per
                # sample. Leaf positions get score -1 (below every U[0, 1)
                # draw), so they only fill top-k slots when a sample has fewer
                # than k non-leaf positions. Those slots are then dropped by
                # the non_leaf_block mask.
                k = min(positions_per_step, block_size)
                scores = torch.rand(B, block_size, device=device)
                scores = torch.where(
                    non_leaf_block, scores, torch.full_like(scores, -1.0),
                )
                _, topk_idx = scores.topk(k, dim=-1)  # [B, k]
                selected = torch.zeros_like(non_leaf_block)
                selected.scatter_(1, topk_idx, True)
                apply_mask = selected & non_leaf_block  # [B, bs]

                level_ids[:, block_start:block_end] = torch.where(
                    apply_mask, best_level, cur_level_block,
                )
                value_ids[:, block_start:block_end] = torch.where(
                    apply_mask, best_id, value_ids[:, block_start:block_end],
                )

                if return_intermediate:
                    intermediate.append(
                        (level_ids.clone().cpu(), value_ids.clone().cpu())
                    )
                total_steps += 1

            # Safety net: force any lingering non-leaf positions to leaf, using
            # the same temperature-scaled leaf distribution.
            block_level = level_ids[:, block_start:block_end]
            non_leaf = block_level > 0
            if non_leaf.any():
                block_logits, _ = self._forward_block_cached(
                    level_ids[:, block_start:block_end],
                    value_ids[:, block_start:block_end],
                    b, kv_cache,
                )
                leaf_logits_fp = block_logits.float()
                if self.leaf_temperature != 1.0:
                    leaf_logits_fp = leaf_logits_fp / self.leaf_temperature
                leaf_prob_fallback = F.softmax(leaf_logits_fp, dim=-1)
                leaf_id_fallback = torch.multinomial(
                    leaf_prob_fallback.reshape(-1, leaf_prob_fallback.shape[-1]),
                    num_samples=1,
                ).squeeze(-1).reshape(B, block_size)
                level_ids[:, block_start:block_end] = torch.where(
                    non_leaf, torch.zeros_like(block_level), block_level,
                )
                value_ids[:, block_start:block_end] = torch.where(
                    non_leaf, leaf_id_fallback, value_ids[:, block_start:block_end],
                )

            # ── Finalize block b in the KV cache: one more forward on the
            # all-leaf state so the cached K/V match the resolved tokens.
            # is_clean=True: finalized blocks are clean context (segment=1).
            _, new_kv = self._forward_block_cached(
                level_ids[:, block_start:block_end],
                value_ids[:, block_start:block_end],
                b, kv_cache,
                is_clean=True,
            )
            kv_cache = self._append_kv(kv_cache, new_kv)

            if stop_on_eos and eos_id is not None:
                block_vals = value_ids[:, block_start:block_end]
                block_lvls = level_ids[:, block_start:block_end]
                has_eos = ((block_lvls == 0) & (block_vals == eos_id)).any(dim=-1)
                finished = finished | has_eos
                if finished.all():
                    break

        # An early EOS stop leaves later blocks at mask level with value 0,
        # which would otherwise be returned (and decoded) as token id 0. Pad
        # them with EOS instead so every returned position is a real leaf.
        if eos_id is not None:
            self._fill_unresolved(level_ids, value_ids, eos_id)

        # ── Package output. Every position is leaf, so value_ids holds token ids.
        result = {
            "tokens": value_ids.cpu(),
            "prompt_len": P,
            "num_steps": total_steps,
        }
        if return_intermediate:
            result["intermediate"] = intermediate
        return result
# ─────────────────────────────────────────────────────────────────────────────
# Checkpoint / model plumbing
# ─────────────────────────────────────────────────────────────────────────────
def _unwrap(model):
"""Peel DDP (.module) and torch.compile (._orig_mod) wrappers."""
while True:
if hasattr(model, "_orig_mod"):
model = model._orig_mod
elif hasattr(model, "module"):
model = model.module
else:
return model
def load_config(path: str) -> dict:
    """
    Load a YAML config file and return its contents.

    Uses `yaml.safe_load`, so no arbitrary Python objects are constructed.
    The file is opened as UTF-8 explicitly so configs containing non-ASCII
    text parse the same regardless of the platform's locale encoding.
    """
    with open(path, encoding="utf-8") as f:
        return yaml.safe_load(f)
def build_tokenizer(config: dict):
    """
    Load the local GPT-2 tokenizer from `<ROOT>/tokenizers/gpt2` (no network)
    and ensure it has the special tokens the sampler needs.

    A missing eos or mask token is added. A missing bos or pad token is
    aliased to eos. As a side effect, `config["model"]["vocab_size"]` (and
    `config["model"]["level_sizes"][0]`, when that key exists) is overwritten
    with the final tokenizer length.
    """
    from transformers import AutoTokenizer

    tokenizer_dir = ROOT / "tokenizers" / "gpt2"
    tok = AutoTokenizer.from_pretrained(tokenizer_dir, local_files_only=True)

    if tok.eos_token is None:
        tok.add_special_tokens({"eos_token": "<|endoftext|>"})
    # bos, then pad, fall back to the eos token when unset.
    for attr_name in ("bos_token", "pad_token"):
        if getattr(tok, attr_name) is None:
            setattr(tok, attr_name, tok.eos_token)
    if tok.mask_token_id is None:
        tok.add_special_tokens({"mask_token": "[MASK]"})

    final_vocab = len(tok)
    model_cfg = config["model"]
    model_cfg["vocab_size"] = final_vocab
    if "level_sizes" in model_cfg:
        model_cfg["level_sizes"][0] = final_vocab
    return tok
def build_ancestor_table(config: dict, device, embed_dim: int) -> AncestorTable:
    """
    Build an AncestorTable with the same shapes as in training (mirror of
    train_sad.build_ancestor_table) so a checkpoint's state_dict can be
    loaded into it.

    If `config["ancestor"]["lut_path"]` is unset, a random single-level LUT
    is generated (debug path). Otherwise the LUT, plus the optional
    `proto_path`, is loaded from disk. Relative paths resolve against ROOT.
    """
    anc_cfg = config.get("ancestor", {})
    lut_path = anc_cfg.get("lut_path", None)

    if lut_path is None:
        # Debug path: random LUT seeded with the training seed so it lines up
        # with the one drawn at training time. The checkpoint's state_dict
        # overwrites the learnable embeddings anyway.
        n_vocab = config["model"]["vocab_size"]
        n_clusters = anc_cfg.get("num_clusters", 64)
        k_top = anc_cfg.get("top_k", 3)
        train_seed = config.get("training", {}).get("seed", 42)
        gen = torch.Generator().manual_seed(train_seed)
        # Draw order (randint → rand → randn) fixes the generated values.
        lut_idx = torch.randint(0, n_clusters, (n_vocab, k_top), generator=gen)
        weights = torch.rand(n_vocab, k_top, generator=gen)
        lut_p = weights / weights.sum(dim=-1, keepdim=True)
        emb0 = torch.randn(n_clusters, embed_dim, generator=gen) * 0.02
        random_table = AncestorTable(
            lut_indices=[lut_idx],
            lut_probs=[lut_p],
            init_embeddings=[emb0],
        )
        return random_table.to(device)

    def _resolve(p):
        # Absolute paths are kept; relative ones are anchored at the repo root.
        p = Path(p)
        return p if p.is_absolute() else ROOT / p

    proto_path = anc_cfg.get("proto_path", None)
    table = AncestorTable.from_files(
        lut_path=_resolve(lut_path),
        proto_path=None if proto_path is None else _resolve(proto_path),
        embed_dim=embed_dim,
        device=device,
    )
    return table.to(device)
def build_model(config: dict, device: torch.device) -> SADModel:
    """
    Construct a fresh SADModel from `config["model"]` and move it to `device`.

    Required keys: vocab_size, hidden_size, n_blocks, n_heads, cond_dim,
    max_seq_len. Optional keys and their fallbacks: block_size=8,
    dropout=0.0, num_levels=2, level_sizes=None, tie_weights=False.
    """
    mc = config["model"]
    required = {
        key: mc[key]
        for key in (
            "vocab_size", "hidden_size", "n_blocks",
            "n_heads", "cond_dim", "max_seq_len",
        )
    }
    optional = {
        "block_size": mc.get("block_size", 8),
        "dropout": mc.get("dropout", 0.0),
        "num_levels": mc.get("num_levels", 2),
        "level_sizes": mc.get("level_sizes"),
        "tie_weights": mc.get("tie_weights", False),
    }
    return SADModel(**required, **optional).to(device)
# ─────────────────────────────────────────────────────────────────────────────
# CLI
# ─────────────────────────────────────────────────────────────────────────────
def parse_args():
    """
    Parse command-line arguments for the sampling script.

    EOS early-stopping is on by default. `--no_stop_on_eos` turns it off.
    `--stop_on_eos` is kept and still accepted, so existing invocations work.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--checkpoint", type=str, required=True)
    p.add_argument("--config", type=str, default="configs/sad_owt.yaml")
    p.add_argument("--num_samples", type=int, default=1)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--device", type=str,
                   default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument("--dtype", type=str, default="bf16", choices=["bf16", "fp16", "fp32"])
    # Both flags share dest="stop_on_eos". A lone store_true with default=True
    # could never be switched off from the command line.
    p.add_argument("--stop_on_eos", dest="stop_on_eos", action="store_true",
                   default=True,
                   help="Stop once every sample has produced EOS (default: on).")
    p.add_argument("--no_stop_on_eos", dest="stop_on_eos", action="store_false",
                   help="Always generate the full max_seq_len.")
    p.add_argument("--mode", type=str, default="unconditional",
                   choices=["unconditional", "conditional"],
                   help="unconditional: start from all-mask. "
                        "conditional: take a block from the training set as the first block(s).")
    p.add_argument("--prompt_blocks", type=int, default=1,
                   help="(conditional) number of leading blocks taken from the training data.")
    p.add_argument("--data_seed", type=int, default=0,
                   help="(conditional) seed for shuffling the training split when picking a sample.")
    p.add_argument("--positions_per_step", type=int, default=1,
                   help="Number of random non-leaf positions to advance per "
                        "denoising round within a block (must be >= 1).")
    p.add_argument("--level_lambdas", type=str, default=None,
                   help="Comma-separated K floats in [0, 1], one per ancestor "
                        "level l = 1..K (e.g. '1.0,0.8,0.5'). Multiplies the "
                        "level's max-prob conf before the cross-level argmax. "
                        "λ_l < 1 biases the schedule away from level l; "
                        "λ_l = 0 disables it. Default: all 1.0 (no change).")
    p.add_argument("--leaf_temperature", type=float, default=1.0,
                   help="Temperature (> 0) applied to leaf logits before the "
                        "softmax used for the leaf confidence and the leaf "
                        "multinomial sample. Values < 1.0 sharpen it. Ancestor "
                        "levels always use the temperature-1 softmax. "
                        "Default 1.0 (no sharpening).")
    return p.parse_args()
def resolve_dtype(name: str) -> torch.dtype:
    """
    Map a CLI dtype flag to its torch dtype.

    Accepts 'bf16', 'fp16' or 'fp32'. Any other name raises KeyError.
    """
    by_flag = {
        "fp32": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }
    return by_flag[name]
def main():
    """
    CLI entry point: build the tokenizer, model and ancestor table from the
    config + checkpoint, optionally take a prompt from the training data,
    sample, and print the decoded text.

    Order matters: build_tokenizer() overwrites config["model"]["vocab_size"]
    (and level_sizes[0]) before build_model() and build_ancestor_table()
    read the config.
    """
    args = parse_args()
    torch.manual_seed(args.seed)
    device = torch.device(args.device)
    dtype = resolve_dtype(args.dtype)
    config = load_config(args.config)
    tokenizer = build_tokenizer(config)

    # ── Build + load model ─────────────────────────────────────────────────
    model = build_model(config, device).to(dtype)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(args.checkpoint, map_location=device)
    raw_state = ckpt.get("model", ckpt)
    # strict=False tolerates benign mismatches, but report them rather than
    # swallow them. For example, keys saved with a wrapper prefix such as
    # "module." or "_orig_mod." would otherwise leave the weights silently
    # at their random init.
    incompatible = _unwrap(model).load_state_dict(raw_state, strict=False)
    if incompatible.missing_keys:
        print(f"[warn] {len(incompatible.missing_keys)} model key(s) missing "
              f"from checkpoint, e.g. {list(incompatible.missing_keys)[:5]}")
    if incompatible.unexpected_keys:
        print(f"[warn] {len(incompatible.unexpected_keys)} unexpected key(s) "
              f"in checkpoint, e.g. {list(incompatible.unexpected_keys)[:5]}")
    model.eval()
    print(f"Loaded checkpoint: {args.checkpoint} (step={ckpt.get('step', '?')})")

    # ── Build + load ancestor table ────────────────────────────────────────
    # The fixed LUT comes from the config (same file as training). The
    # checkpoint's "ancestor_table" state_dict then overwrites the table's
    # buffers and parameters.
    assert "ancestor_table" in ckpt, (
        "Checkpoint has no 'ancestor_table' entry — cannot run hierarchical "
        "inference. Re-train with train_sad.py or use an older inference "
        "script that ignores ancestors."
    )
    ancestor_table = build_ancestor_table(
        config, device, embed_dim=config["model"]["hidden_size"],
    )
    ancestor_table.load_state_dict(ckpt["ancestor_table"])
    ancestor_table.to(device=device, dtype=dtype).eval()
    print(f"Loaded ancestor table: {ancestor_table.num_levels} ancestor level(s)")

    level_lambdas = None
    if args.level_lambdas:
        level_lambdas = [float(x) for x in args.level_lambdas.split(",")]

    sampler = BlockDiffusionSampler(
        model=_unwrap(model),
        ancestor_table=ancestor_table,
        tokenizer=tokenizer,
        device=device,
        dtype=dtype,
        level_lambdas=level_lambdas,
        leaf_temperature=args.leaf_temperature,
    )
    print(f"level_lambdas (per ancestor level l=1..K) = "
          f"{sampler.level_lambdas[1:]}")
    print(f"leaf_temperature = {sampler.leaf_temperature}")

    # ── Optionally load a prompt from the training data ────────────────────
    prompt_ids = None
    if args.mode == "conditional":
        data_cfg = config.get("data", {})
        # Use the sizes the model was actually built with. build_model()
        # defaults block_size to 8 when the config omits it, so reading
        # config["model"]["block_size"] directly could raise KeyError.
        seq_len = sampler.max_seq_len
        block_size = sampler.block_size
        prompt_len = args.prompt_blocks * block_size
        assert prompt_len < seq_len, (
            f"prompt_blocks * block_size = {prompt_len} must be < max_seq_len = {seq_len}"
        )
        # Resolve a relative cache_dir against the repo root (ROOT) when that
        # path exists, so the script works regardless of the current directory.
        cache_dir = data_cfg.get("cache_dir", None)
        if cache_dir is not None and not Path(cache_dir).is_absolute():
            candidate = ROOT / cache_dir
            if candidate.exists():
                cache_dir = str(candidate)
        loader = build_owt_dataloader(
            tokenizer,
            split="train[:-100000]",
            seq_len=seq_len,
            batch_size=args.num_samples,
            num_workers=0,
            cache_dir=cache_dir,
            seed=args.data_seed,
            mode=data_cfg.get("mode", "subsample"),
            shard_across_ranks=False,
        )
        batch = next(iter(loader))
        prompt_ids = batch["input_ids"][:args.num_samples, :prompt_len].to(device)
        print(f"Loaded conditional prompt from training data: "
              f"shape={tuple(prompt_ids.shape)} (prompt_blocks={args.prompt_blocks})")

    print(f"Sampling {args.num_samples} sequences ({args.mode}) "
          f"length={config['model']['max_seq_len']}, "
          f"random positions_per_step={args.positions_per_step}")

    out = sampler.generate(
        batch_size=args.num_samples if prompt_ids is None else None,
        prompt_ids=prompt_ids,
        positions_per_step=args.positions_per_step,
        stop_on_eos=args.stop_on_eos,
    )

    # ── Decode & print ─────────────────────────────────────────────────────
    P = out.get("prompt_len", 0)
    print("\n" + "=" * 72)
    for i, ids in enumerate(out["tokens"]):
        ids_list = ids.tolist()
        print(f"[Sample {i + 1}]")
        if P > 0:
            prompt_text = tokenizer.decode(ids_list[:P], skip_special_tokens=True)
            gen_text = tokenizer.decode(ids_list[P:], skip_special_tokens=True)
            print(f"<prompt ({P} tok)> {prompt_text}")
            print(f"<generated> {gen_text}")
        else:
            print(tokenizer.decode(ids_list, skip_special_tokens=True))
        print()
print()
if __name__ == "__main__":
main()