#!/usr/bin/env python3
"""
inference_sad.py – Block-wise hierarchical diffusion sampling from a trained SADModel.

Generation proceeds block by block left-to-right. Within each block, a small
random subset of non-leaf positions is advanced each round to some strictly
finer level in the hierarchy

    mask (level K+1)  >  ancestors (K, …, 1)  >  leaf (level 0)

A transition may jump any number of levels (e.g. mask → leaf directly, or
ancestor l → ancestor l' with l' < l, or ancestor → leaf) as long as the new
level is strictly finer than the current one — never stay, never revert.
Rounds repeat until every position in the block is leaf; then the next block
begins.

Each denoising round:
  1. One forward pass on the current block (K/V cache holds earlier blocks).
  2. Score every strictly-finer candidate level of every block position:
       - leaf level (l=0): conf is the max of softmax(logits / leaf_temperature);
         the target id is a multinomial draw from that same distribution.
       - ancestor level (l≥1): the temperature-1 leaf softmax is projected
         through the fixed LUT (`AncestorTable.projection_matrix`) into
         cluster probabilities; conf is λ_l · max cluster prob and the target
         id is a multinomial draw from the (unscaled) cluster probabilities.
     λ_l ∈ [0, 1] is a per-level scalar: smaller λ_l biases the schedule away
     from that ancestor level, λ_l = 0 disables it, and the default λ = 1
     leaves the confidence untouched. Leaf (l=0) is never scaled by λ.
     The eligible level with the highest conf wins; ties go to the finer
     level. Note that with leaf_temperature != 1 the leaf confidence is
     sharpened while ancestor confidences are not.
  3. Randomly pick `positions_per_step` non-leaf positions per sample and
     transition each to its best strictly-finer level.

Finalized blocks' K/V are cached so forwards only recompute the current block.

Usage:
    python scripts/inference_sad.py \\
        --config configs/sad_owt.yaml \\
        --checkpoint outputs/sad/latest.pt \\
        --num_samples 4
"""

from __future__ import annotations

import argparse
import sys
from pathlib import Path
from typing import Optional

import torch
import torch.nn.functional as F
import yaml
from einops import rearrange

ROOT = Path(__file__).resolve().parents[1]  # sad/

# Project imports below resolve against the repo root, so it must be on
# sys.path before they run.
sys.path.insert(0, str(ROOT))

from src.models.sad_model import SADModel  # noqa: E402
from src.models.dit_components import apply_rotary_pos_emb, modulate_fused  # noqa: E402
from src.diffusion.ancestor_table import AncestorTable  # noqa: E402
from src.data import build_owt_dataloader  # noqa: E402


# ─────────────────────────────────────────────────────────────────────────────
# Sampler
# ─────────────────────────────────────────────────────────────────────────────

class BlockDiffusionSampler:
    """
    Block-wise hierarchical diffusion sampler for SADModel.

    State per position is (level, value):
        level = 0           → leaf token;            value = token id
        level ∈ [1, K]      → ancestor at level l;   value = cluster id in K_l
        level = K + 1       → mask

    Per-block denoising loop (random position selection, strict-descent
    schedule). Until every position in the block is leaf:
      1. Forward pass on the current block (cache holds earlier blocks).
      2. Vectorized over all block positions, score each candidate level:
           leaf target (l=0):      prob = softmax(logits / leaf_temperature)
           ancestor target (l≥1):  prob = softmax(logits) @ W_l   [V, K_l]
         Each candidate level contributes (conf, id): conf is the max-prob
         (times λ_l for ancestors) and is used only to compare levels; the id
         is a multinomial draw from that level's distribution. Only levels
         strictly finer than the position's current level are eligible — so
         mask → leaf (skipping every ancestor) is a legal transition, as is
         any multi-level jump. The eligible level with the highest confidence
         wins.
      3. Randomly pick `positions_per_step` non-leaf positions per sample and
         apply the selected transition at those positions only.
    """

    def __init__(
        self,
        model: SADModel,
        ancestor_table: AncestorTable,
        tokenizer,
        device: torch.device,
        dtype: torch.dtype = torch.bfloat16,
        level_lambdas: Optional[list] = None,
        leaf_temperature: float = 1.0,
    ):
        """
        Args:
            model: trained SADModel; `block_size`, `max_seq_len`, `vocab_size`
                and the leaf embedding table are read from it.
            ancestor_table: provides `num_levels`, per-level ancestor
                embeddings and per-level projection matrices W_l.
            tokenizer: must expose a non-None `mask_token_id`.
            device: device all sampler tensors live on.
            dtype: dtype of the embeddings fed to the model.
            level_lambdas: length-K list of floats in [0, 1]. λ_l (for
                ancestor level l = 1..K) multiplies that level's max-prob
                conf before the cross-level comparison that picks the winning
                target. Leaf (l=0) is never scaled. None → all ones.
            leaf_temperature: positive temperature applied to leaf logits
                before softmax. Values < 1.0 sharpen the leaf distribution,
                which is used for the leaf confidence and for leaf sampling.
                Ancestor projection always uses the temperature-1 softmax.
                Default 1.0 (no temperature scaling).
        """
        self.model = model
        self.ancestor_table = ancestor_table
        self.tokenizer = tokenizer
        self.device = device
        self.dtype = dtype
        self.leaf_temperature = float(leaf_temperature)
        # Logits are divided by the temperature; zero/negative values would
        # yield inf/NaN or an inverted distribution.
        assert self.leaf_temperature > 0.0, (
            f"leaf_temperature must be > 0, got {leaf_temperature}"
        )

        self.block_size: int = model.block_size
        self.max_seq_len: int = model.max_seq_len
        self.vocab_size: int = model.vocab_size
        self.mask_id: int = tokenizer.mask_token_id
        assert self.mask_id is not None, "tokenizer must have mask_token_id"

        self.K: int = ancestor_table.num_levels  # number of ancestor levels
        self.mask_level: int = self.K + 1

        if level_lambdas is None:
            level_lambdas = [1.0] * self.K
        assert len(level_lambdas) == self.K, (
            f"level_lambdas must have length K={self.K}, got {len(level_lambdas)}"
        )
        for x in level_lambdas:
            assert 0.0 <= float(x) <= 1.0, f"each λ must be in [0, 1], got {x}"
        # 1-indexed: self.level_lambdas[l] is λ_l for ancestor level l ∈ [1, K]
        self.level_lambdas = [None] + [float(x) for x in level_lambdas]

        # Leaf embedding table (tied with output head — read-only view).
        self.leaf_emb = model.get_leaf_embeddings().to(device=device, dtype=dtype).detach()
        self.mask_emb = self.leaf_emb[self.mask_id]  # [d]

        # Ancestor embeddings per level: fed into the model, so keep them in
        # self.dtype to match model weights. Index 0 is a placeholder so the
        # list is addressed by level.
        self.anc_embs = [None] + [
            ancestor_table.ancestor_embeddings(l).to(device=device, dtype=dtype).detach()
            for l in range(1, self.K + 1)
        ]

        # LUT projection matrices W_l: used only on the scoring side (fp32).
        # Fixed buffers, no grad, so fp32 storage is cheap.
        self.W = [None] + [
            ancestor_table.projection_matrix(l).to(device=device, dtype=torch.float32).detach()
            for l in range(1, self.K + 1)
        ]

    # ───────────────────────────────────────────────────────────────────────
    def _build_mixed_embeddings(
        self,
        level_ids: torch.Tensor,
        value_ids: torch.Tensor,
    ) -> torch.Tensor:
        """
        Build [B, S, d] input embeddings from per-position (level, value).

        Mirrors NoisyStateBuilder.build_noisy_embeddings so inference-time
        inputs match the training distribution. Positions whose level is
        outside [0, K+1] are left uninitialized (torch.empty); callers are
        expected to pass valid levels only.
        """
        B, S = level_ids.shape
        d = self.leaf_emb.shape[-1]
        embs = torch.empty(B, S, d, device=self.device, dtype=self.dtype)

        # leaf (level 0) — leaf_emb[value]
        m0 = (level_ids == 0)
        if m0.any():
            embs[m0] = self.leaf_emb[value_ids[m0]]

        # mask (level K+1) — leaf_emb[mask_id]
        mM = (level_ids == self.mask_level)
        if mM.any():
            embs[mM] = self.mask_emb

        # ancestor levels 1..K — anc_embs[l][value]
        for l in range(1, self.K + 1):
            ml = (level_ids == l)
            if ml.any():
                embs[ml] = self.anc_embs[l][value_ids[ml]]

        return embs

    # ───────────────────────────────────────────────────────────────────────
    # KV-cache–aware forward. The key observation: under the block-causal mask,
    # the K/V produced at positions in finalized (leaf) earlier blocks are
    # deterministic and never change. So we compute them once per block and
    # reuse them across all denoising rounds of the current block.
    #
    # This method inlines DDiTBlockWithMask.forward so we can (a) accept a K/V
    # prefix cache, (b) avoid recomputing Q/K/V for earlier blocks. When
    # k_prefix is None it also serves as an uncached single-block pass (used
    # for prompt blocks and the final K/V capture).
    # ───────────────────────────────────────────────────────────────────────
    def _run_layer_cached(
        self,
        layer_idx: int,
        x: torch.Tensor,
        rotary_cos_sin,
        c: torch.Tensor,
        k_prefix: Optional[torch.Tensor] = None,
        v_prefix: Optional[torch.Tensor] = None,
    ):
        """
        Run one DiT block on `x` (current block positions only) with an
        optional cached K/V prefix.

        Args:
            layer_idx: index into self.model.blocks
            x: [B, bs, d] current block hidden state
            rotary_cos_sin: rotary cos/sin for positions block_start..block_end-1
            c: [B, cond_dim] conditioning
            k_prefix, v_prefix: [B, H, S_prefix, d_head] post-rotary cached K/V
                                (from earlier blocks). None means no prefix.

        Returns:
            x_out: [B, bs, d]
            k_new: [B, H, bs, d_head]  post-rotary K for current block
            v_new: [B, H, bs, d_head]  post-rotary V for current block
        """
        layer = self.model.blocks[layer_idx]
        B = x.shape[0]
        H = layer.n_heads
        dropout = layer.dropout

        bds_fn = layer._bias_dropout_scale_fn()
        (shift_msa, scale_msa, gate_msa,
         shift_mlp, scale_mlp, gate_mlp) = layer.adaLN_modulation(c)[:, None].chunk(6, dim=2)

        x_skip = x
        x_normed = modulate_fused(layer.norm1(x), shift_msa, scale_msa)

        qkv = layer.attn_qkv(x_normed)
        qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=H)
        cos, sin = rotary_cos_sin
        qkv = apply_rotary_pos_emb(qkv, cos.to(qkv.dtype), sin.to(qkv.dtype))

        q = qkv[:, :, 0].transpose(1, 2)      # [B, H, bs, d_h]
        k_new = qkv[:, :, 1].transpose(1, 2)  # [B, H, bs, d_h]
        v_new = qkv[:, :, 2].transpose(1, 2)

        if k_prefix is not None:
            k = torch.cat([k_prefix, k_new], dim=2)
            v = torch.cat([v_prefix, v_new], dim=2)
        else:
            k = k_new
            v = v_new

        # No mask: current block may attend to all prefix (block-causal lookback)
        # and to itself (bidirectional within block).
        attn_out = F.scaled_dot_product_attention(q, k, v)
        attn_out = rearrange(attn_out, "b h s d -> b s (h d)", b=B)

        x = bds_fn(layer.attn_out(attn_out), None, gate_msa, x_skip, dropout)
        x = bds_fn(
            layer.mlp(modulate_fused(layer.norm2(x), shift_mlp, scale_mlp)),
            None, gate_mlp, x, dropout,
        )
        return x, k_new, v_new

    def _forward_block_cached(
        self,
        level_ids_cur: torch.Tensor,
        value_ids_cur: torch.Tensor,
        block_idx: int,
        kv_cache: list,
        is_clean: bool = False,
    ):
        """
        Forward pass over a single block using cached prefix K/V.

        Args:
            level_ids_cur, value_ids_cur: [B, bs] current block state
            block_idx: int, absolute block index (for pos/rotary)
            kv_cache: list[(k_prefix, v_prefix) or (None, None)] per layer
            is_clean: if True, use segment_embed(1) (clean half) to match
                      training's clean context. Used when capturing K/V for
                      finalized blocks and prompt warm-up.

        Returns:
            logits_cur: [B, bs, V]  (mask column already set to -inf)
            new_kv: list[(k_cur, v_cur)] per layer — caller appends to cache
        """
        model = self.model
        B, bs = level_ids_cur.shape
        block_start = block_idx * self.block_size
        block_end = block_start + bs
        device = self.device

        embs = self._build_mixed_embeddings(level_ids_cur, value_ids_cur)  # [B, bs, d]

        # Input projection (weights are self.dtype; embs already self.dtype).
        x = model.input_proj(embs)

        # Position embeddings for this block only. Every index tensor is sized
        # by the actual block width `bs` so the three embeddings always add up.
        block_idx_t = torch.full((bs,), block_idx, dtype=torch.long, device=device)
        intra_pos = torch.arange(bs, device=device)
        # segment=0 for noisy (denoising rounds), segment=1 for clean (cache capture)
        seg_val = 1 if is_clean else 0
        seg_id = torch.full((bs,), seg_val, dtype=torch.long, device=device)

        pos_emb = (
            model.block_idx_embed(block_idx_t)
            + model.intra_pos_embed(intra_pos)
            + model.segment_embed(seg_id)
        ).unsqueeze(0).to(x.dtype)
        x = x + pos_emb

        c = model.cond_bias.unsqueeze(0).expand(B, -1).to(x.dtype)

        # Rotary for absolute positions of this block.
        position_ids = torch.arange(block_start, block_end, device=device)
        rotary_cos_sin = model.rotary_emb(x, position_ids=position_ids)

        new_kv = []
        autocast_device = "cuda" if device.type == "cuda" else "cpu"
        with torch.autocast(device_type=autocast_device, dtype=self.dtype):
            for layer_idx in range(len(model.blocks)):
                k_prefix, v_prefix = kv_cache[layer_idx]
                x, k_cur, v_cur = self._run_layer_cached(
                    layer_idx, x, rotary_cos_sin, c,
                    k_prefix=k_prefix, v_prefix=v_prefix,
                )
                new_kv.append((k_cur, v_cur))
            logits = model.output_layer(x, c)  # [B, bs, rounded_leaf]

        # Drop vocabulary padding and make the mask token unselectable.
        logits = logits[..., :self.vocab_size]
        logits[..., self.mask_id] = float("-inf")
        return logits, new_kv

    @staticmethod
    def _append_kv(kv_cache: list, new_kv: list) -> list:
        """Append per-layer new_kv to kv_cache along the sequence dim."""
        out = []
        for (kp, vp), (kn, vn) in zip(kv_cache, new_kv):
            if kp is None:
                out.append((kn, vn))
            else:
                out.append((torch.cat([kp, kn], dim=2), torch.cat([vp, vn], dim=2)))
        return out

    # ───────────────────────────────────────────────────────────────────────
    @staticmethod
    def _sample_categorical(probs: torch.Tensor) -> torch.Tensor:
        """Draw one index from the last dim of `probs`: [..., C] → [...]."""
        flat = probs.reshape(-1, probs.shape[-1])
        draws = torch.multinomial(flat, num_samples=1).squeeze(-1)
        return draws.reshape(probs.shape[:-1])

    def _best_transitions(
        self,
        block_logits: torch.Tensor,
        cur_level: torch.Tensor,
    ):
        """
        Pick the best strictly-finer (level, id) target for every position.

        Args:
            block_logits: [B, bs, V] leaf logits of the current block.
            cur_level: [B, bs] current level of every block position.

        Returns:
            best_level: [B, bs] winning target level; -1 where no level is
                eligible (i.e. the position is already leaf).
            best_id: [B, bs] sampled id at the winning level (0 where none).
        """
        device = self.device
        leaf_logits_fp = block_logits.float()
        # Raw (temperature=1) probs feed the ancestor projection; the
        # temperature-sharpened probs feed the leaf conf and leaf sampling.
        leaf_prob_raw = F.softmax(leaf_logits_fp, dim=-1)  # [B, bs, V]
        if self.leaf_temperature != 1.0:
            leaf_prob_temp = F.softmax(
                leaf_logits_fp / self.leaf_temperature, dim=-1,
            )  # [B, bs, V]
        else:
            leaf_prob_temp = leaf_prob_raw

        best_conf = torch.full(
            cur_level.shape, float("-inf"), device=device, dtype=torch.float32,
        )
        best_level = torch.full(cur_level.shape, -1, device=device, dtype=torch.long)
        best_id = torch.zeros(cur_level.shape, device=device, dtype=torch.long)

        # Leaf target (l = 0): conf and sample both from the temp-sharpened dist.
        leaf_conf = leaf_prob_temp.max(dim=-1).values  # [B, bs]
        leaf_id = self._sample_categorical(leaf_prob_temp)  # [B, bs]
        upd = (cur_level > 0) & (leaf_conf > best_conf)
        best_conf = torch.where(upd, leaf_conf, best_conf)
        best_level = torch.where(upd, torch.zeros_like(best_level), best_level)
        best_id = torch.where(upd, leaf_id, best_id)

        # Ancestor targets l = 1..K: conf is λ_l times the max RAW cluster
        # prob; the id is sampled from the RAW cluster probs. The strict `>`
        # means ties are resolved in favor of the finer level seen earlier.
        for l in range(1, self.K + 1):
            V_anc = self.W[l].shape[0]
            cluster_prob_raw = leaf_prob_raw[..., :V_anc] @ self.W[l]  # [B, bs, K_l]
            conf_l = cluster_prob_raw.max(dim=-1).values * self.level_lambdas[l]
            id_l = self._sample_categorical(cluster_prob_raw)  # [B, bs]
            upd = (cur_level > l) & (conf_l > best_conf)
            best_conf = torch.where(upd, conf_l, best_conf)
            best_level = torch.where(upd, torch.full_like(best_level, l), best_level)
            best_id = torch.where(upd, id_l, best_id)

        return best_level, best_id

    def _sample_leaf_fallback(self, block_logits: torch.Tensor) -> torch.Tensor:
        """Sample a leaf id per position from the temperature-scaled softmax."""
        leaf_logits_fp = block_logits.float()
        if self.leaf_temperature != 1.0:
            leaf_logits_fp = leaf_logits_fp / self.leaf_temperature
        return self._sample_categorical(F.softmax(leaf_logits_fp, dim=-1))

    # ───────────────────────────────────────────────────────────────────────
    @torch.no_grad()
    def generate(
        self,
        batch_size: Optional[int] = None,
        prompt_ids: Optional[torch.Tensor] = None,
        positions_per_step: int = 1,
        return_intermediate: bool = False,
        stop_on_eos: bool = True,
    ) -> dict:
        """
        Block-by-block generation with KV cache and random per-round position
        selection.

        Within each block, rounds repeat until every position is leaf. Each
        round runs one forward, computes the best strictly-finer target
        (level, id) for every non-leaf position, then picks
        `positions_per_step` random non-leaf positions per sample and applies
        their transitions.

        Unconditional: pass `batch_size` (and leave `prompt_ids=None`); starts
        from an all-mask sequence of length `self.max_seq_len`.

        Conditional: pass `prompt_ids` with shape [B, P] where P is a multiple
        of `block_size`; the first P positions are fixed as leaf tokens, the
        remaining positions are generated block by block.

        Args:
            batch_size: number of sequences (unconditional mode).
            prompt_ids: [B, P] leaf token ids (conditional mode).
            positions_per_step: non-leaf positions advanced per round.
            return_intermediate: also return the (level_ids, value_ids) state
                after every denoising round, as CPU tensors.
            stop_on_eos: stop once every sample contains the tokenizer's EOS
                token in a generated block. Blocks that were never generated
                because of the early stop are filled with the EOS id in the
                returned tokens.

        Returns:
            dict with "tokens" ([B, max_seq_len] CPU tensor), "prompt_len",
            "num_steps" (total denoising rounds) and, if requested,
            "intermediate".
        """
        block_size = self.block_size
        device = self.device
        total_len = self.max_seq_len
        assert total_len % block_size == 0, (
            f"max_seq_len ({total_len}) must be divisible by block_size "
            f"({block_size})"
        )

        if prompt_ids is not None:
            prompt_ids = prompt_ids.to(device=device, dtype=torch.long)
            B, P = prompt_ids.shape
            assert P % block_size == 0, (
                f"prompt length P={P} must be a multiple of block_size={block_size}"
            )
            assert P < total_len, (
                f"prompt length P={P} must be < total_len={total_len}"
            )
            start_block = P // block_size
        else:
            assert batch_size is not None, (
                "Either batch_size (unconditional) or prompt_ids (conditional) "
                "must be provided."
            )
            B = batch_size
            P = 0
            start_block = 0

        # ── Initialize state: every position is mask; prompt positions set as leaf.
        level_ids = torch.full(
            (B, total_len), self.mask_level, dtype=torch.long, device=device,
        )
        value_ids = torch.zeros((B, total_len), dtype=torch.long, device=device)
        if P > 0:
            level_ids[:, :P] = 0
            value_ids[:, :P] = prompt_ids

        num_blocks = total_len // block_size
        intermediate = [] if return_intermediate else None
        finished = torch.zeros(B, dtype=torch.bool, device=device)
        eos_id = getattr(self.tokenizer, "eos_token_id", None)

        # ── KV cache: per-layer (k_prefix, v_prefix) for finalized blocks.
        # Starts empty; we append block b's K/V after b is fully resolved,
        # so when block b+1 starts the cache covers blocks 0..b.
        num_layers = len(self.model.blocks)
        kv_cache = [(None, None) for _ in range(num_layers)]

        # ── Warm up KV cache over prompt blocks (all leaf, deterministic).
        # Use is_clean=True: prompt blocks act as clean context for later blocks,
        # matching training's clean half (segment=1).
        for b in range(start_block):
            span = slice(b * block_size, (b + 1) * block_size)
            _, new_kv = self._forward_block_cached(
                level_ids[:, span], value_ids[:, span], b, kv_cache, is_clean=True,
            )
            kv_cache = self._append_kv(kv_cache, new_kv)

        # ── Block loop (skips prompt blocks). ──────────────────────────────
        # Each round advances up to `positions_per_step` non-leaf positions by
        # ≥1 level each (strict descent). Worst case every position needs K+1
        # transitions → cap at block_size * (K+1) rounds, which is slack.
        rounds_cap_per_block = block_size * (self.K + 1)
        total_steps = 0  # total denoising rounds across all generated blocks

        for b in range(start_block, num_blocks):
            span = slice(b * block_size, (b + 1) * block_size)

            for _ in range(rounds_cap_per_block):
                cur_level_block = level_ids[:, span]      # [B, bs]
                non_leaf_block = (cur_level_block > 0)    # [B, bs]
                if not non_leaf_block.any():
                    break

                # 1) Forward pass on current block (cache holds blocks 0..b-1).
                block_logits, _ = self._forward_block_cached(
                    cur_level_block, value_ids[:, span], b, kv_cache,
                )  # [B, bs, V]

                # 2) Best strictly-finer target for every block position.
                best_level, best_id = self._best_transitions(
                    block_logits, cur_level_block,
                )

                # 3) Randomly pick `positions_per_step` non-leaf positions per
                #    sample. Leaf positions get score = -1 (below any rand()
                #    draw) so they only fill a top-k slot when a sample has
                #    fewer than k non-leaf positions; those extra slots are
                #    dropped via the explicit non_leaf_block mask.
                k = min(positions_per_step, block_size)
                scores = torch.rand(B, block_size, device=device)
                scores = torch.where(
                    non_leaf_block, scores, torch.full_like(scores, -1.0),
                )
                _, topk_idx = scores.topk(k, dim=-1)  # [B, k]
                selected = torch.zeros_like(non_leaf_block)
                selected.scatter_(1, topk_idx, True)
                apply_mask = selected & non_leaf_block  # [B, bs]

                level_ids[:, span] = torch.where(
                    apply_mask, best_level, cur_level_block,
                )
                value_ids[:, span] = torch.where(
                    apply_mask, best_id, value_ids[:, span],
                )

                if return_intermediate:
                    intermediate.append(
                        (level_ids.clone().cpu(), value_ids.clone().cpu())
                    )
                total_steps += 1

            # Safety net: force any lingering non-leaf positions to leaf.
            # Use the same temperature-sharpened distribution for consistency.
            block_level = level_ids[:, span]
            non_leaf = (block_level > 0)
            if non_leaf.any():
                block_logits, _ = self._forward_block_cached(
                    block_level, value_ids[:, span], b, kv_cache,
                )
                leaf_id_fallback = self._sample_leaf_fallback(block_logits)
                level_ids[:, span] = torch.where(
                    non_leaf, torch.zeros_like(block_level), block_level,
                )
                value_ids[:, span] = torch.where(
                    non_leaf, leaf_id_fallback, value_ids[:, span],
                )

            # ── Early stop once every sample has produced EOS ──────────────
            if stop_on_eos and eos_id is not None:
                has_eos = (
                    (level_ids[:, span] == 0) & (value_ids[:, span] == eos_id)
                ).any(dim=-1)
                finished = finished | has_eos
                if finished.all():
                    break

            # ── Finalize block b in the KV cache ───────────────────────────
            # Run one more forward on the block's final (all-leaf) state to
            # grab K/V that are consistent with the resolved tokens, then
            # append to the cache so block b+1 can see block b.
            # Use is_clean=True: finalized blocks serve as clean context for
            # later blocks, matching training's clean half (segment=1).
            # Skipped for the last block: nothing would ever read that K/V.
            if b + 1 < num_blocks:
                _, new_kv = self._forward_block_cached(
                    level_ids[:, span], value_ids[:, span], b, kv_cache,
                    is_clean=True,
                )
                kv_cache = self._append_kv(kv_cache, new_kv)

        # ── Package output ──────────────────────────────────────────────────
        # Positions are non-leaf here only if the EOS early stop skipped their
        # block; their value is still the 0 placeholder, which is a real token
        # id. Replace it with EOS so the returned tokens never contain
        # placeholder ids. (No-op when every block was generated.)
        if eos_id is not None:
            value_ids.masked_fill_(level_ids > 0, eos_id)

        result = {
            "tokens": value_ids.cpu(),
            "prompt_len": P,
            "num_steps": total_steps,
        }
        if return_intermediate:
            result["intermediate"] = intermediate
        return result


# ─────────────────────────────────────────────────────────────────────────────
# Checkpoint / model plumbing
# ─────────────────────────────────────────────────────────────────────────────

def _unwrap(model):
    """Peel DDP (.module) and torch.compile (._orig_mod) wrappers."""
    while True:
        if hasattr(model, "_orig_mod"):
            model = model._orig_mod
        elif hasattr(model, "module"):
            model = model.module
        else:
            return model


def load_config(path: str) -> dict:
    """Load the YAML config at `path` with the safe loader."""
    with open(path, encoding="utf-8") as f:
        return yaml.safe_load(f)


def build_tokenizer(config: dict):
    """
    Load the local GPT-2 tokenizer and make sure it has eos/bos/pad/mask tokens.

    Side effect: writes the final tokenizer length into
    config["model"]["vocab_size"] (and into level_sizes[0] when present) so
    the model is built with a matching vocabulary.
    """
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained(
        ROOT / "tokenizers" / "gpt2", local_files_only=True,
    )
    if tok.eos_token is None:
        tok.add_special_tokens({"eos_token": "<|endoftext|>"})
    if tok.bos_token is None:
        tok.bos_token = tok.eos_token
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    if tok.mask_token_id is None:
        tok.add_special_tokens({"mask_token": "[MASK]"})
    config["model"]["vocab_size"] = len(tok)
    if "level_sizes" in config["model"]:
        config["model"]["level_sizes"][0] = len(tok)
    return tok


def _resolve_repo_path(path) -> Path:
    """Return `path` as a Path, resolving relative paths against ROOT."""
    path = Path(path)
    return path if path.is_absolute() else ROOT / path


def build_ancestor_table(config: dict, device, embed_dim: int) -> AncestorTable:
    """
    Mirror of train_sad.build_ancestor_table — load fixed LUT (and proto) so
    the returned module has the right shape for ckpt state_dict loading.

    Without `ancestor.lut_path` in the config a seeded random single-level
    LUT is built instead (debug path).
    """
    ancestor_cfg = config.get("ancestor", {})

    lut_path = ancestor_cfg.get("lut_path", None)
    if lut_path is None:
        # Debug path: random LUT. Uses the training seed so the random LUT
        # lines up across train/infer — checkpoint's state_dict will overwrite
        # the learnable embeddings anyway.
        vocab_size = config["model"]["vocab_size"]
        K = ancestor_cfg.get("num_clusters", 64)
        top_k = ancestor_cfg.get("top_k", 3)
        seed = config.get("training", {}).get("seed", 42)
        g = torch.Generator().manual_seed(seed)
        indices = torch.randint(0, K, (vocab_size, top_k), generator=g)
        raw_w = torch.rand(vocab_size, top_k, generator=g)
        probs = raw_w / raw_w.sum(dim=-1, keepdim=True)
        init_emb = torch.randn(K, embed_dim, generator=g) * 0.02
        return AncestorTable(
            lut_indices=[indices], lut_probs=[probs],
            init_embeddings=[init_emb],
        ).to(device)

    lut_path = _resolve_repo_path(lut_path)
    proto_path = ancestor_cfg.get("proto_path", None)
    if proto_path is not None:
        proto_path = _resolve_repo_path(proto_path)

    table = AncestorTable.from_files(
        lut_path=lut_path,
        proto_path=proto_path,
        embed_dim=embed_dim,
        device=device,
    )
    return table.to(device)


def build_model(config: dict, device: torch.device) -> SADModel:
    """Instantiate SADModel from config["model"] and move it to `device`."""
    mc = config["model"]
    model = SADModel(
        vocab_size=mc["vocab_size"],
        hidden_size=mc["hidden_size"],
        n_blocks=mc["n_blocks"],
        n_heads=mc["n_heads"],
        cond_dim=mc["cond_dim"],
        max_seq_len=mc["max_seq_len"],
        block_size=mc.get("block_size", 8),
        dropout=mc.get("dropout", 0.0),
        num_levels=mc.get("num_levels", 2),
        level_sizes=mc.get("level_sizes"),
        tie_weights=mc.get("tie_weights", False),
    ).to(device)
    return model


# ─────────────────────────────────────────────────────────────────────────────
# CLI
# ─────────────────────────────────────────────────────────────────────────────

def parse_args():
    """Parse the command-line arguments of the inference script."""
    p = argparse.ArgumentParser()
    p.add_argument("--checkpoint", type=str, required=True)
    p.add_argument("--config", type=str, default="configs/sad_owt.yaml")
    p.add_argument("--num_samples", type=int, default=1)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--device", type=str,
                   default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument("--dtype", type=str, default="bf16",
                   choices=["bf16", "fp16", "fp32"])
    # BooleanOptionalAction keeps `--stop_on_eos` working and adds
    # `--no-stop_on_eos`; a plain store_true with default=True could never
    # be switched off.
    p.add_argument("--stop_on_eos", action=argparse.BooleanOptionalAction,
                   default=True,
                   help="Stop generating once every sample has produced EOS. "
                        "Use --no-stop_on_eos to always generate all blocks.")
    p.add_argument("--mode", type=str, default="unconditional",
                   choices=["unconditional", "conditional"],
                   help="unconditional: start from all-mask. "
                        "conditional: take a block from the training set as the first block(s).")
    p.add_argument("--prompt_blocks", type=int, default=1,
                   help="(conditional) number of leading blocks taken from the training data.")
    p.add_argument("--data_seed", type=int, default=0,
                   help="(conditional) seed for shuffling the training split when picking a sample.")
    p.add_argument("--positions_per_step", type=int, default=1,
                   help="Number of random non-leaf positions to advance per "
                        "denoising round within a block.")
    p.add_argument("--level_lambdas", type=str, default=None,
                   help="Comma-separated K floats in [0, 1], one per ancestor "
                        "level l = 1..K (e.g. '1.0,0.8,0.5'). Multiplies the "
                        "level's max-prob conf before the cross-level argmax. "
                        "λ_l < 1 biases the schedule away from level l; "
                        "λ_l = 0 disables it. Default: all 1.0 (no change).")
    p.add_argument("--leaf_temperature", type=float, default=1.0,
                   help="Temperature applied to leaf logits before softmax. "
                        "Values < 1.0 sharpen p_leaf, which is used for the "
                        "leaf confidence and leaf multinomial sampling; "
                        "ancestor projection always uses the temperature-1 "
                        "softmax. Default 1.0 (no sharpening).")
    return p.parse_args()


def resolve_dtype(name: str) -> torch.dtype:
    """Map a CLI dtype name ("bf16" / "fp16" / "fp32") to a torch dtype."""
    return {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[name]


def main():
    """CLI entry point: load model + ancestor table, sample, decode, print."""
    args = parse_args()
    torch.manual_seed(args.seed)

    device = torch.device(args.device)
    dtype = resolve_dtype(args.dtype)

    config = load_config(args.config)
    tokenizer = build_tokenizer(config)

    # ── Build + load model ─────────────────────────────────────────────────
    model = build_model(config, device).to(dtype)
    ckpt = torch.load(args.checkpoint, map_location=device)
    raw_state = ckpt.get("model", ckpt)
    # strict=False tolerates key mismatches; surface them so a checkpoint
    # that silently failed to populate the model does not go unnoticed.
    load_report = _unwrap(model).load_state_dict(raw_state, strict=False)
    for label, keys in (
        ("missing", load_report.missing_keys),
        ("unexpected", load_report.unexpected_keys),
    ):
        if keys:
            preview = ", ".join(keys[:5]) + (", …" if len(keys) > 5 else "")
            print(f"WARNING: {len(keys)} {label} key(s) in checkpoint load: {preview}")
    model.eval()
    print(f"Loaded checkpoint: {args.checkpoint}  (step={ckpt.get('step', '?')})")

    # ── Build + load ancestor table ────────────────────────────────────────
    # Fixed LUT comes from config (same file as training); learnable ancestor
    # embeddings come from the checkpoint. load_state_dict overwrites both
    # buffers (LUT, W_l) and parameters (ancestor_embeddings) to match training
    # exactly.
    ancestor_table = build_ancestor_table(
        config, device, embed_dim=config["model"]["hidden_size"],
    )
    assert "ancestor_table" in ckpt, (
        "Checkpoint has no 'ancestor_table' entry — cannot run hierarchical "
        "inference. Re-train with train_sad.py or use an older inference "
        "script that ignores ancestors."
    )
    ancestor_table.load_state_dict(ckpt["ancestor_table"])
    ancestor_table.to(device=device, dtype=dtype).eval()
    print(f"Loaded ancestor table: {ancestor_table.num_levels} ancestor level(s)")

    level_lambdas = None
    if args.level_lambdas:
        level_lambdas = [float(x) for x in args.level_lambdas.split(",")]

    sampler = BlockDiffusionSampler(
        model=_unwrap(model),
        ancestor_table=ancestor_table,
        tokenizer=tokenizer,
        device=device,
        dtype=dtype,
        level_lambdas=level_lambdas,
        leaf_temperature=args.leaf_temperature,
    )
    print(f"level_lambdas (per ancestor level l=1..K) = "
          f"{sampler.level_lambdas[1:]}")
    print(f"leaf_temperature = {sampler.leaf_temperature}")

    # ── Optionally load a prompt from the training data ────────────────────
    prompt_ids = None
    if args.mode == "conditional":
        data_cfg = config.get("data", {})
        seq_len = config["model"]["max_seq_len"]
        # Take block_size from the sampler (i.e. the built model) so a config
        # that relies on build_model's default block_size still works.
        block_size = sampler.block_size
        prompt_len = args.prompt_blocks * block_size
        assert prompt_len < seq_len, (
            f"prompt_blocks * block_size = {prompt_len} must be < max_seq_len = {seq_len}"
        )
        # Resolve relative cache_dir against the sad/ repo root (scripts/..), so
        # the script works regardless of cwd (training ran from sad/).
        cache_dir = data_cfg.get("cache_dir", None)
        if cache_dir is not None and not Path(cache_dir).is_absolute():
            candidate = ROOT / cache_dir
            if candidate.exists():
                cache_dir = str(candidate)
        loader = build_owt_dataloader(
            tokenizer,
            split="train[:-100000]",
            seq_len=seq_len,
            batch_size=args.num_samples,
            num_workers=0,
            cache_dir=cache_dir,
            seed=args.data_seed,
            mode=data_cfg.get("mode", "subsample"),
            shard_across_ranks=False,
        )
        batch = next(iter(loader))
        prompt_ids = batch["input_ids"][:args.num_samples, :prompt_len].to(device)
        print(f"Loaded conditional prompt from training data: "
              f"shape={tuple(prompt_ids.shape)} (prompt_blocks={args.prompt_blocks})")

    print(f"Sampling {args.num_samples} sequences ({args.mode}) "
          f"length={config['model']['max_seq_len']}, "
          f"random positions_per_step={args.positions_per_step}")

    out = sampler.generate(
        batch_size=args.num_samples if prompt_ids is None else None,
        prompt_ids=prompt_ids,
        positions_per_step=args.positions_per_step,
        stop_on_eos=args.stop_on_eos,
    )

    # ── Decode & print ─────────────────────────────────────────────────────
    P = out.get("prompt_len", 0)
    print("\n" + "=" * 72)
    for i, ids in enumerate(out["tokens"]):
        ids_list = ids.tolist()
        print(f"[Sample {i + 1}]")
        if P > 0:
            prompt_text = tokenizer.decode(ids_list[:P], skip_special_tokens=True)
            gen_text = tokenizer.decode(ids_list[P:], skip_special_tokens=True)
            print(f"  {prompt_text}")
            print(f"  {gen_text}")
        else:
            print(tokenizer.decode(ids_list, skip_special_tokens=True))
        print()


if __name__ == "__main__":
    main()