#!/usr/bin/env python3 """ inference_block_diffusion.py - Block-wise mask-diffusion sampling for SADModel. This is the block-diffusion counterpart of inference_sad.py: - no ancestor states / no lambda schedule - each position is either MASK or LEAF - within the current block, each round samples leaf tokens for every masked position, then applies updates to `positions_per_step` random masked positions per sample Finalized earlier blocks are cached as K/V so later blocks only recompute the current block, matching the left-to-right blockwise evaluation setup used by the block-AR checkpoints. """ from __future__ import annotations import argparse import sys from pathlib import Path ROOT = Path(__file__).resolve().parents[1] # sad/ from typing import Optional import torch import torch.nn.functional as F import yaml from einops import rearrange sys.path.insert(0, str(ROOT)) from src.data import build_owt_dataloader from src.models.dit_components import apply_rotary_pos_emb, modulate_fused from src.models.sad_model import SADModel class BlockMaskDiffusionSampler: """Block-wise mask-diffusion sampler with KV-cache reuse.""" def __init__( self, model: SADModel, tokenizer, device: torch.device, dtype: torch.dtype = torch.bfloat16, leaf_temperature: float = 1.0, ): self.model = model self.tokenizer = tokenizer self.device = device self.dtype = dtype self.leaf_temperature = float(leaf_temperature) self.block_size: int = model.block_size self.max_seq_len: int = model.max_seq_len self.vocab_size: int = model.vocab_size self.mask_id: int = tokenizer.mask_token_id assert self.mask_id is not None, "tokenizer must have mask_token_id" self.mask_level = 1 self.leaf_emb = model.get_leaf_embeddings().to( device=device, dtype=dtype ).detach() self.mask_emb = self.leaf_emb[self.mask_id] def _build_mixed_embeddings( self, level_ids: torch.Tensor, value_ids: torch.Tensor ) -> torch.Tensor: """Build [B, S, d] embeddings from {leaf, mask} states.""" B, S = level_ids.shape d = self.leaf_emb.shape[-1] embs = torch.empty(B, S, d, device=self.device, dtype=self.dtype) leaf_mask = level_ids == 0 if leaf_mask.any(): embs[leaf_mask] = self.leaf_emb[value_ids[leaf_mask]] mask_mask = level_ids == self.mask_level if mask_mask.any(): embs[mask_mask] = self.mask_emb return embs def _run_layer_cached( self, layer_idx: int, x: torch.Tensor, rotary_cos_sin, c: torch.Tensor, k_prefix: Optional[torch.Tensor] = None, v_prefix: Optional[torch.Tensor] = None, ): layer = self.model.blocks[layer_idx] B = x.shape[0] H = layer.n_heads dropout = layer.dropout bds_fn = layer._bias_dropout_scale_fn() ( shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, ) = layer.adaLN_modulation(c)[:, None].chunk(6, dim=2) x_skip = x x_normed = modulate_fused(layer.norm1(x), shift_msa, scale_msa) qkv = layer.attn_qkv(x_normed) qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=H) cos, sin = rotary_cos_sin qkv = apply_rotary_pos_emb(qkv, cos.to(qkv.dtype), sin.to(qkv.dtype)) q = qkv[:, :, 0].transpose(1, 2) k_new = qkv[:, :, 1].transpose(1, 2) v_new = qkv[:, :, 2].transpose(1, 2) if k_prefix is not None: k = torch.cat([k_prefix, k_new], dim=2) v = torch.cat([v_prefix, v_new], dim=2) else: k = k_new v = v_new attn_out = F.scaled_dot_product_attention(q, k, v) attn_out = rearrange(attn_out, "b h s d -> b s (h d)", b=B) x = bds_fn(layer.attn_out(attn_out), None, gate_msa, x_skip, dropout) x = bds_fn( layer.mlp(modulate_fused(layer.norm2(x), shift_mlp, scale_mlp)), None, gate_mlp, x, dropout, ) return x, k_new, v_new def _forward_block_cached( self, level_ids_cur: torch.Tensor, value_ids_cur: torch.Tensor, block_idx: int, kv_cache: list, is_clean: bool = False, ): model = self.model B, bs = level_ids_cur.shape block_start = block_idx * self.block_size block_end = block_start + bs device = self.device embs = self._build_mixed_embeddings(level_ids_cur, value_ids_cur) x = model.input_proj(embs) block_idx_t = torch.full((bs,), block_idx, dtype=torch.long, device=device) intra_pos = torch.arange(self.block_size, device=device) seg_id = torch.full( (bs,), 1 if is_clean else 0, dtype=torch.long, device=device ) pos_emb = ( model.block_idx_embed(block_idx_t) + model.intra_pos_embed(intra_pos) + model.segment_embed(seg_id) ).unsqueeze(0).to(x.dtype) x = x + pos_emb c = model.cond_bias.unsqueeze(0).expand(B, -1).to(x.dtype) position_ids = torch.arange(block_start, block_end, device=device) rotary_cos_sin = model.rotary_emb(x, position_ids=position_ids) new_kv = [] autocast_device = "cuda" if device.type == "cuda" else "cpu" with torch.autocast(device_type=autocast_device, dtype=self.dtype): for layer_idx in range(len(model.blocks)): k_prefix, v_prefix = kv_cache[layer_idx] x, k_cur, v_cur = self._run_layer_cached( layer_idx, x, rotary_cos_sin, c, k_prefix=k_prefix, v_prefix=v_prefix, ) new_kv.append((k_cur, v_cur)) logits = model.output_layer(x, c) logits = logits[..., : self.vocab_size] logits[..., self.mask_id] = float("-inf") return logits, new_kv @staticmethod def _append_kv(kv_cache: list, new_kv: list) -> list: out = [] for (kp, vp), (kn, vn) in zip(kv_cache, new_kv): if kp is None: out.append((kn, vn)) else: out.append((torch.cat([kp, kn], dim=2), torch.cat([vp, vn], dim=2))) return out @torch.no_grad() def generate( self, batch_size: Optional[int] = None, prompt_ids: Optional[torch.Tensor] = None, positions_per_step: int = 1, return_intermediate: bool = False, stop_on_eos: bool = True, ) -> dict: block_size = self.block_size device = self.device total_len = self.max_seq_len assert total_len % block_size == 0, ( f"max_seq_len ({total_len}) must be divisible by block_size ({block_size})" ) if prompt_ids is not None: prompt_ids = prompt_ids.to(device=device, dtype=torch.long) B, P = prompt_ids.shape assert P % block_size == 0, ( f"prompt length P={P} must be a multiple of block_size={block_size}" ) assert P < total_len, ( f"prompt length P={P} must be < total_len={total_len}" ) start_block = P // block_size else: assert batch_size is not None, "batch_size or prompt_ids must be provided" B = batch_size P = 0 start_block = 0 level_ids = torch.full( (B, total_len), self.mask_level, dtype=torch.long, device=device ) value_ids = torch.zeros((B, total_len), dtype=torch.long, device=device) if P > 0: level_ids[:, :P] = 0 value_ids[:, :P] = prompt_ids num_blocks = total_len // block_size intermediate = [] if return_intermediate else None finished = torch.zeros(B, dtype=torch.bool, device=device) eos_id = getattr(self.tokenizer, "eos_token_id", None) num_layers = len(self.model.blocks) kv_cache = [(None, None) for _ in range(num_layers)] for b in range(start_block): bs0 = b * block_size be0 = (b + 1) * block_size _, new_kv = self._forward_block_cached( level_ids[:, bs0:be0], value_ids[:, bs0:be0], b, kv_cache, is_clean=True, ) kv_cache = self._append_kv(kv_cache, new_kv) total_steps = 0 rounds_cap_per_block = block_size for b in range(start_block, num_blocks): block_start = b * block_size block_end = (b + 1) * block_size for _ in range(rounds_cap_per_block): cur_level_block = level_ids[:, block_start:block_end] non_leaf_block = cur_level_block > 0 if not non_leaf_block.any(): break block_logits, _ = self._forward_block_cached( level_ids[:, block_start:block_end], value_ids[:, block_start:block_end], b, kv_cache, ) logits_fp = block_logits.float() if self.leaf_temperature != 1.0: logits_fp = logits_fp / self.leaf_temperature leaf_prob = F.softmax(logits_fp, dim=-1) leaf_conf = leaf_prob.max(dim=-1).values leaf_id = torch.multinomial( leaf_prob.reshape(-1, leaf_prob.shape[-1]), num_samples=1, ).squeeze(-1).reshape(B, block_size) k = min(positions_per_step, block_size) scores = torch.rand(B, block_size, device=device) scores = torch.where( non_leaf_block, scores, torch.full_like(scores, -1.0) ) _, topk_idx = scores.topk(k, dim=-1) selected = torch.zeros_like(non_leaf_block) selected.scatter_(1, topk_idx, True) apply_mask = selected & non_leaf_block block_levels = level_ids[:, block_start:block_end] block_values = value_ids[:, block_start:block_end] level_ids[:, block_start:block_end] = torch.where( apply_mask, torch.zeros_like(block_levels), block_levels ) value_ids[:, block_start:block_end] = torch.where( apply_mask, leaf_id, block_values ) if return_intermediate: intermediate.append( (level_ids.clone().cpu(), value_ids.clone().cpu()) ) total_steps += 1 block_level = level_ids[:, block_start:block_end] non_leaf = block_level > 0 if non_leaf.any(): block_logits, _ = self._forward_block_cached( level_ids[:, block_start:block_end], value_ids[:, block_start:block_end], b, kv_cache, ) logits_fp = block_logits.float() if self.leaf_temperature != 1.0: logits_fp = logits_fp / self.leaf_temperature leaf_prob = F.softmax(logits_fp, dim=-1) leaf_id = torch.multinomial( leaf_prob.reshape(-1, leaf_prob.shape[-1]), num_samples=1, ).squeeze(-1).reshape(B, block_size) level_ids[:, block_start:block_end] = torch.where( non_leaf, torch.zeros_like(block_level), block_level ) value_ids[:, block_start:block_end] = torch.where( non_leaf, leaf_id, value_ids[:, block_start:block_end] ) _, new_kv = self._forward_block_cached( level_ids[:, block_start:block_end], value_ids[:, block_start:block_end], b, kv_cache, is_clean=True, ) kv_cache = self._append_kv(kv_cache, new_kv) if stop_on_eos and eos_id is not None: block_vals = value_ids[:, block_start:block_end] has_eos = block_vals.eq(eos_id).any(dim=-1) finished = finished | has_eos if finished.all(): break result = { "tokens": value_ids.cpu(), "prompt_len": P, "num_steps": total_steps, } if return_intermediate: result["intermediate"] = intermediate return result def _unwrap(model): while True: if hasattr(model, "_orig_mod"): model = model._orig_mod elif hasattr(model, "module"): model = model.module else: return model def load_config(path: str) -> dict: with open(path) as f: return yaml.safe_load(f) def build_tokenizer(config: dict): from transformers import AutoTokenizer tok = AutoTokenizer.from_pretrained( ROOT / "tokenizers" / "gpt2", local_files_only=True, ) if tok.eos_token is None: tok.add_special_tokens({"eos_token": "<|endoftext|>"}) if tok.bos_token is None: tok.bos_token = tok.eos_token if tok.pad_token is None: tok.pad_token = tok.eos_token if tok.mask_token_id is None: tok.add_special_tokens({"mask_token": "[MASK]"}) config["model"]["vocab_size"] = len(tok) if "level_sizes" in config["model"] and config["model"]["level_sizes"]: config["model"]["level_sizes"][0] = len(tok) return tok def build_model(config: dict, device: torch.device) -> SADModel: mc = config["model"] return SADModel( vocab_size=mc["vocab_size"], hidden_size=mc["hidden_size"], n_blocks=mc["n_blocks"], n_heads=mc["n_heads"], cond_dim=mc["cond_dim"], max_seq_len=mc["max_seq_len"], block_size=mc.get("block_size", 8), dropout=mc.get("dropout", 0.0), num_levels=mc.get("num_levels", 1), level_sizes=mc.get("level_sizes"), tie_weights=mc.get("tie_weights", False), ).to(device) def resolve_dtype(name: str) -> torch.dtype: return {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[name] def parse_args(): p = argparse.ArgumentParser() p.add_argument("--checkpoint", type=str, required=True) p.add_argument("--config", type=str, default="configs/block_diffusion_owt_b32.yaml") p.add_argument("--num_samples", type=int, default=1) p.add_argument("--seed", type=int, default=42) p.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu" ) p.add_argument( "--dtype", type=str, default="bf16", choices=["bf16", "fp16", "fp32"] ) p.add_argument("--stop_on_eos", action="store_true", default=True) p.add_argument( "--mode", type=str, default="unconditional", choices=["unconditional", "conditional"], ) p.add_argument("--prompt_blocks", type=int, default=1) p.add_argument("--data_seed", type=int, default=0) p.add_argument("--positions_per_step", type=int, default=1) p.add_argument("--leaf_temperature", type=float, default=1.0) return p.parse_args() def main(): args = parse_args() torch.manual_seed(args.seed) device = torch.device(args.device) dtype = resolve_dtype(args.dtype) config = load_config(args.config) tokenizer = build_tokenizer(config) model = build_model(config, device).to(dtype) ckpt = torch.load(args.checkpoint, map_location=device) raw_state = ckpt.get("model", ckpt) _unwrap(model).load_state_dict(raw_state, strict=False) model.eval() print(f"Loaded checkpoint: {args.checkpoint} (step={ckpt.get('step', '?')})") sampler = BlockMaskDiffusionSampler( model=_unwrap(model), tokenizer=tokenizer, device=device, dtype=dtype, leaf_temperature=args.leaf_temperature, ) print(f"leaf_temperature = {sampler.leaf_temperature}") prompt_ids = None if args.mode == "conditional": data_cfg = config.get("data", {}) seq_len = config["model"]["max_seq_len"] block_size = config["model"]["block_size"] prompt_len = args.prompt_blocks * block_size assert prompt_len < seq_len, ( f"prompt_blocks * block_size = {prompt_len} must be < max_seq_len = {seq_len}" ) cache_dir = data_cfg.get("cache_dir", None) if cache_dir is not None and not Path(cache_dir).is_absolute(): repo_root = ROOT candidate = repo_root / cache_dir if candidate.exists(): cache_dir = str(candidate) loader = build_owt_dataloader( tokenizer, split="train[:-100000]", seq_len=seq_len, batch_size=args.num_samples, num_workers=0, cache_dir=cache_dir, seed=args.data_seed, mode=data_cfg.get("mode", "subsample"), shard_across_ranks=False, ) batch = next(iter(loader)) prompt_ids = batch["input_ids"][: args.num_samples, :prompt_len].to(device) print( "Loaded conditional prompt from training data: " f"shape={tuple(prompt_ids.shape)} (prompt_blocks={args.prompt_blocks})" ) out = sampler.generate( batch_size=args.num_samples if prompt_ids is None else None, prompt_ids=prompt_ids, positions_per_step=args.positions_per_step, stop_on_eos=args.stop_on_eos, ) prompt_len = out.get("prompt_len", 0) print("\n" + "=" * 72) for i, ids in enumerate(out["tokens"]): ids_list = ids.tolist() print(f"[Sample {i + 1}]") if prompt_len > 0: prompt_text = tokenizer.decode(ids_list[:prompt_len], skip_special_tokens=True) gen_text = tokenizer.decode(ids_list[prompt_len:], skip_special_tokens=True) print(f" {prompt_text}") print(f" {gen_text}") else: print(tokenizer.decode(ids_list, skip_special_tokens=True)) print() if __name__ == "__main__": main()