| |
| """ |
| inference_block_diffusion.py - Block-wise mask-diffusion sampling for SADModel. |
| |
| This is the block-diffusion counterpart of inference_sad.py: |
| - no ancestor states / no lambda schedule |
| - each position is either MASK or LEAF |
| - within the current block, each round samples leaf tokens for every masked |
| position, then applies updates to `positions_per_step` random masked |
| positions per sample |
| |
| Finalized earlier blocks are cached as K/V so later blocks only recompute the |
| current block, matching the left-to-right blockwise evaluation setup used by |
| the block-AR checkpoints. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import sys |
| from pathlib import Path |
|
|
| ROOT = Path(__file__).resolve().parents[1] |
| from typing import Optional |
|
|
| import torch |
| import torch.nn.functional as F |
| import yaml |
| from einops import rearrange |
|
|
| sys.path.insert(0, str(ROOT)) |
|
|
| from src.data import build_owt_dataloader |
| from src.models.dit_components import apply_rotary_pos_emb, modulate_fused |
| from src.models.sad_model import SADModel |
|
|
|
|
class BlockMaskDiffusionSampler:
    """Block-wise mask-diffusion sampler with KV-cache reuse.

    The sequence state is held in two parallel integer tensors of shape
    ``[B, max_seq_len]``:

    * ``level_ids`` - ``0`` for a finalized (leaf) token, ``mask_level`` for a
      position that is still masked;
    * ``value_ids`` - the token id, only read where the level is ``0``.

    Blocks are generated left to right.  Once a block is fully unmasked it is
    run through the model one more time and the per-layer keys/values of that
    pass are appended to a cache, so later blocks only run the transformer
    over their own positions.
    """

    def __init__(
        self,
        model: "SADModel",
        tokenizer,
        device: torch.device,
        dtype: torch.dtype = torch.bfloat16,
        leaf_temperature: float = 1.0,
    ):
        """Capture the model, tokenizer and sampling settings.

        Args:
            model: Network providing ``block_size``, ``max_seq_len``,
                ``vocab_size``, ``blocks`` and ``get_leaf_embeddings()``.
            tokenizer: Tokenizer exposing ``mask_token_id`` (and optionally
                ``eos_token_id``).
            device: Device on which all sampling tensors are created.
            dtype: Dtype of the embedding table copy and of the autocast
                region used for the transformer layers.
            leaf_temperature: Logits are divided by this value before the
                softmax; must be strictly positive.

        Raises:
            ValueError: If ``leaf_temperature`` is not strictly positive.
            AssertionError: If the tokenizer has no mask token.
        """
        leaf_temperature = float(leaf_temperature)
        # A zero temperature turns the logits into inf/NaN and a negative one
        # silently inverts the distribution, so reject both up front.
        if not leaf_temperature > 0.0:
            raise ValueError(
                f"leaf_temperature must be > 0, got {leaf_temperature}"
            )

        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.dtype = dtype
        self.leaf_temperature = leaf_temperature

        self.block_size: int = model.block_size
        self.max_seq_len: int = model.max_seq_len
        self.vocab_size: int = model.vocab_size
        self.mask_id: int = tokenizer.mask_token_id
        assert self.mask_id is not None, "tokenizer must have mask_token_id"

        # Level value that marks a still-masked position (0 marks a leaf).
        self.mask_level = 1
        # Detached copy of the token embedding table; it is indexed by token id.
        self.leaf_emb = model.get_leaf_embeddings().to(
            device=device, dtype=dtype
        ).detach()
        # Masked positions are all fed the embedding row of the mask token.
        self.mask_emb = self.leaf_emb[self.mask_id]

    def _build_mixed_embeddings(
        self, level_ids: torch.Tensor, value_ids: torch.Tensor
    ) -> torch.Tensor:
        """Build [B, S, d] embeddings from {leaf, mask} states.

        Leaf positions (level 0) receive the embedding row of their token id;
        positions at ``mask_level`` receive the mask embedding.  The buffer is
        allocated with ``torch.empty``, so a position whose level is neither
        of the two would be left uninitialized.
        """
        B, S = level_ids.shape
        d = self.leaf_emb.shape[-1]
        embs = torch.empty(B, S, d, device=self.device, dtype=self.dtype)

        leaf_mask = level_ids == 0
        if leaf_mask.any():
            embs[leaf_mask] = self.leaf_emb[value_ids[leaf_mask]]

        mask_mask = level_ids == self.mask_level
        if mask_mask.any():
            embs[mask_mask] = self.mask_emb

        return embs

    def _run_layer_cached(
        self,
        layer_idx: int,
        x: torch.Tensor,
        rotary_cos_sin,
        c: torch.Tensor,
        k_prefix: Optional[torch.Tensor] = None,
        v_prefix: Optional[torch.Tensor] = None,
    ):
        """Run one transformer block over the current positions only.

        Args:
            layer_idx: Index into ``self.model.blocks``.
            x: Hidden states of the current block, ``[B, S, hidden]``.
            rotary_cos_sin: ``(cos, sin)`` pair for the current positions.
            c: Conditioning vector passed to the layer's adaLN modulation.
            k_prefix: Cached keys of earlier blocks, or ``None``.
            v_prefix: Cached values of earlier blocks, or ``None``.

        Returns:
            ``(x, k_new, v_new)`` where ``k_new`` / ``v_new`` are the keys and
            values of the current positions only (the prefix is not included).
        """
        layer = self.model.blocks[layer_idx]
        B = x.shape[0]
        H = layer.n_heads
        dropout = layer.dropout
        bds_fn = layer._bias_dropout_scale_fn()

        # One modulation vector per sample, broadcast over the sequence axis
        # and split into the six shift/scale/gate terms.
        (
            shift_msa,
            scale_msa,
            gate_msa,
            shift_mlp,
            scale_mlp,
            gate_mlp,
        ) = layer.adaLN_modulation(c)[:, None].chunk(6, dim=2)

        x_skip = x
        x_normed = modulate_fused(layer.norm1(x), shift_msa, scale_msa)
        qkv = layer.attn_qkv(x_normed)
        qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=H)
        cos, sin = rotary_cos_sin
        qkv = apply_rotary_pos_emb(qkv, cos.to(qkv.dtype), sin.to(qkv.dtype))

        # Move heads in front of the sequence axis for attention.
        q = qkv[:, :, 0].transpose(1, 2)
        k_new = qkv[:, :, 1].transpose(1, 2)
        v_new = qkv[:, :, 2].transpose(1, 2)

        # Queries come from the current block only; keys/values additionally
        # cover every cached earlier position (concatenated on the sequence
        # axis, dim 2).
        if k_prefix is not None:
            k = torch.cat([k_prefix, k_new], dim=2)
            v = torch.cat([v_prefix, v_new], dim=2)
        else:
            k = k_new
            v = v_new

        # No attention mask: the block attends to the whole prefix and to
        # itself bidirectionally.
        attn_out = F.scaled_dot_product_attention(q, k, v)
        attn_out = rearrange(attn_out, "b h s d -> b s (h d)", b=B)

        x = bds_fn(layer.attn_out(attn_out), None, gate_msa, x_skip, dropout)
        x = bds_fn(
            layer.mlp(modulate_fused(layer.norm2(x), shift_mlp, scale_mlp)),
            None,
            gate_mlp,
            x,
            dropout,
        )
        return x, k_new, v_new

    def _forward_block_cached(
        self,
        level_ids_cur: torch.Tensor,
        value_ids_cur: torch.Tensor,
        block_idx: int,
        kv_cache: list,
        is_clean: bool = False,
    ):
        """Run the model over one block, attending to the cached prefix.

        Args:
            level_ids_cur: Levels of the current block, ``[B, bs]``.
            value_ids_cur: Token ids of the current block, ``[B, bs]``.
            block_idx: Zero-based index of the block within the sequence.
            kv_cache: One ``(k_prefix, v_prefix)`` pair per layer; both
                entries are ``None`` while nothing has been cached yet.
            is_clean: Selects segment id 1 (finalized block) instead of 0
                (block still being denoised).

        Returns:
            ``(logits, new_kv)``: logits truncated to ``vocab_size`` with the
            mask token set to ``-inf``, and the per-layer ``(k, v)`` of the
            current block.  ``kv_cache`` itself is not modified.
        """
        model = self.model
        B, bs = level_ids_cur.shape
        block_start = block_idx * self.block_size
        block_end = block_start + bs
        device = self.device

        embs = self._build_mixed_embeddings(level_ids_cur, value_ids_cur)
        x = model.input_proj(embs)

        block_idx_t = torch.full((bs,), block_idx, dtype=torch.long, device=device)
        # Sized by the actual block width so it always matches block_idx_t and
        # seg_id (identical to block_size for the full blocks generate() uses).
        intra_pos = torch.arange(bs, device=device)
        seg_id = torch.full(
            (bs,), 1 if is_clean else 0, dtype=torch.long, device=device
        )
        pos_emb = (
            model.block_idx_embed(block_idx_t)
            + model.intra_pos_embed(intra_pos)
            + model.segment_embed(seg_id)
        ).unsqueeze(0).to(x.dtype)
        x = x + pos_emb

        c = model.cond_bias.unsqueeze(0).expand(B, -1).to(x.dtype)
        # Rotary embeddings use absolute positions within the full sequence.
        position_ids = torch.arange(block_start, block_end, device=device)
        rotary_cos_sin = model.rotary_emb(x, position_ids=position_ids)

        new_kv = []
        autocast_device = "cuda" if device.type == "cuda" else "cpu"
        with torch.autocast(device_type=autocast_device, dtype=self.dtype):
            for layer_idx in range(len(model.blocks)):
                k_prefix, v_prefix = kv_cache[layer_idx]
                x, k_cur, v_cur = self._run_layer_cached(
                    layer_idx,
                    x,
                    rotary_cos_sin,
                    c,
                    k_prefix=k_prefix,
                    v_prefix=v_prefix,
                )
                new_kv.append((k_cur, v_cur))
            logits = model.output_layer(x, c)

        logits = logits[..., : self.vocab_size]
        # The mask token must never be sampled as a leaf.
        logits[..., self.mask_id] = float("-inf")
        return logits, new_kv

    @staticmethod
    def _append_kv(kv_cache: list, new_kv: list) -> list:
        """Return a new per-layer cache with ``new_kv`` appended on dim 2.

        Layers whose cached key is ``None`` simply adopt the new pair.  The
        input lists are not modified.
        """
        out = []
        for (kp, vp), (kn, vn) in zip(kv_cache, new_kv):
            if kp is None:
                out.append((kn, vn))
            else:
                out.append((torch.cat([kp, kn], dim=2), torch.cat([vp, vn], dim=2)))
        return out

    def _sample_leaf_ids(
        self,
        level_block: torch.Tensor,
        value_block: torch.Tensor,
        block_idx: int,
        kv_cache: list,
    ) -> torch.Tensor:
        """Sample one candidate token id for every position of a block.

        Runs the block through the model, applies the leaf temperature and
        draws from the resulting categorical distribution.  The result has the
        same shape as ``level_block``; the caller decides which of the samples
        are actually written back.
        """
        block_logits, _ = self._forward_block_cached(
            level_block, value_block, block_idx, kv_cache
        )
        logits_fp = block_logits.float()
        if self.leaf_temperature != 1.0:
            logits_fp = logits_fp / self.leaf_temperature
        leaf_prob = F.softmax(logits_fp, dim=-1)
        flat_ids = torch.multinomial(
            leaf_prob.reshape(-1, leaf_prob.shape[-1]),
            num_samples=1,
        )
        return flat_ids.squeeze(-1).reshape(level_block.shape)

    @torch.no_grad()
    def generate(
        self,
        batch_size: Optional[int] = None,
        prompt_ids: Optional[torch.Tensor] = None,
        positions_per_step: int = 1,
        return_intermediate: bool = False,
        stop_on_eos: bool = True,
    ) -> dict:
        """Generate sequences block by block.

        Args:
            batch_size: Number of samples; required when ``prompt_ids`` is
                ``None`` and ignored otherwise.
            prompt_ids: Optional ``[B, P]`` prompt whose length must be a
                multiple of ``block_size`` and smaller than ``max_seq_len``.
            positions_per_step: Number of masked positions unmasked per sample
                in each round (capped at ``block_size``); must be >= 1.
            return_intermediate: Also return a CPU snapshot of
                ``(level_ids, value_ids)`` after every round.
            stop_on_eos: Stop after the first block at which every sample has
                produced an EOS token in some generated block.

        Returns:
            A dict with ``"tokens"`` (``[B, max_seq_len]`` on CPU),
            ``"prompt_len"``, ``"num_steps"`` (number of unmasking rounds) and,
            if requested, ``"intermediate"``.  When generation stops early on
            EOS, the blocks that were never generated keep token id 0.

        Raises:
            ValueError: If ``positions_per_step`` is smaller than 1.
            AssertionError: On an invalid prompt length, a ``max_seq_len`` not
                divisible by ``block_size``, or when neither ``batch_size``
                nor ``prompt_ids`` is given.
        """
        # With 0 no position is ever selected (every round is a wasted forward
        # pass) and a negative value fails inside topk, so reject both early.
        if positions_per_step < 1:
            raise ValueError(
                f"positions_per_step must be >= 1, got {positions_per_step}"
            )

        block_size = self.block_size
        device = self.device
        total_len = self.max_seq_len

        assert total_len % block_size == 0, (
            f"max_seq_len ({total_len}) must be divisible by block_size ({block_size})"
        )

        if prompt_ids is not None:
            prompt_ids = prompt_ids.to(device=device, dtype=torch.long)
            B, P = prompt_ids.shape
            assert P % block_size == 0, (
                f"prompt length P={P} must be a multiple of block_size={block_size}"
            )
            assert P < total_len, (
                f"prompt length P={P} must be < total_len={total_len}"
            )
            start_block = P // block_size
        else:
            assert batch_size is not None, "batch_size or prompt_ids must be provided"
            B = batch_size
            P = 0
            start_block = 0

        # Everything starts masked; the prompt (if any) is marked as leaf.
        level_ids = torch.full(
            (B, total_len), self.mask_level, dtype=torch.long, device=device
        )
        value_ids = torch.zeros((B, total_len), dtype=torch.long, device=device)
        if P > 0:
            level_ids[:, :P] = 0
            value_ids[:, :P] = prompt_ids

        num_blocks = total_len // block_size
        intermediate = [] if return_intermediate else None
        finished = torch.zeros(B, dtype=torch.bool, device=device)
        eos_id = getattr(self.tokenizer, "eos_token_id", None)

        num_layers = len(self.model.blocks)
        kv_cache = [(None, None) for _ in range(num_layers)]

        # Prefill: cache the keys/values of every prompt block.
        for b in range(start_block):
            bs0 = b * block_size
            be0 = (b + 1) * block_size
            _, new_kv = self._forward_block_cached(
                level_ids[:, bs0:be0],
                value_ids[:, bs0:be0],
                b,
                kv_cache,
                is_clean=True,
            )
            kv_cache = self._append_kv(kv_cache, new_kv)

        total_steps = 0
        # At least one position per sample is unmasked each round, so
        # block_size rounds are always enough to finish a block.
        rounds_cap_per_block = block_size
        num_unmask = min(positions_per_step, block_size)

        for b in range(start_block, num_blocks):
            block_start = b * block_size
            block_end = (b + 1) * block_size

            for _ in range(rounds_cap_per_block):
                non_leaf_block = level_ids[:, block_start:block_end] > 0
                if not non_leaf_block.any():
                    break

                leaf_id = self._sample_leaf_ids(
                    level_ids[:, block_start:block_end],
                    value_ids[:, block_start:block_end],
                    b,
                    kv_cache,
                )

                # Pick random masked positions: already-final positions get
                # score -1 so they rank below every masked one, and the final
                # AND drops them if fewer than num_unmask masks remain.
                scores = torch.rand(B, block_size, device=device)
                scores = torch.where(
                    non_leaf_block, scores, torch.full_like(scores, -1.0)
                )
                _, topk_idx = scores.topk(num_unmask, dim=-1)
                selected = torch.zeros_like(non_leaf_block)
                selected.scatter_(1, topk_idx, True)
                apply_mask = selected & non_leaf_block

                block_levels = level_ids[:, block_start:block_end]
                block_values = value_ids[:, block_start:block_end]
                level_ids[:, block_start:block_end] = torch.where(
                    apply_mask, torch.zeros_like(block_levels), block_levels
                )
                value_ids[:, block_start:block_end] = torch.where(
                    apply_mask, leaf_id, block_values
                )

                if return_intermediate:
                    intermediate.append(
                        (level_ids.clone().cpu(), value_ids.clone().cpu())
                    )

                total_steps += 1

            # Safety net: if the round cap was hit with masks left over, fill
            # every remaining masked position in one final sampling pass.
            block_level = level_ids[:, block_start:block_end]
            non_leaf = block_level > 0
            if non_leaf.any():
                leaf_id = self._sample_leaf_ids(
                    level_ids[:, block_start:block_end],
                    value_ids[:, block_start:block_end],
                    b,
                    kv_cache,
                )
                level_ids[:, block_start:block_end] = torch.where(
                    non_leaf, torch.zeros_like(block_level), block_level
                )
                value_ids[:, block_start:block_end] = torch.where(
                    non_leaf, leaf_id, value_ids[:, block_start:block_end]
                )

            # Re-run the finished block as "clean" and cache its keys/values
            # for all later blocks.
            _, new_kv = self._forward_block_cached(
                level_ids[:, block_start:block_end],
                value_ids[:, block_start:block_end],
                b,
                kv_cache,
                is_clean=True,
            )
            kv_cache = self._append_kv(kv_cache, new_kv)

            if stop_on_eos and eos_id is not None:
                block_vals = value_ids[:, block_start:block_end]
                has_eos = block_vals.eq(eos_id).any(dim=-1)
                finished = finished | has_eos
                if finished.all():
                    break

        result = {
            "tokens": value_ids.cpu(),
            "prompt_len": P,
            "num_steps": total_steps,
        }
        if return_intermediate:
            result["intermediate"] = intermediate
        return result
|
|
|
|
| def _unwrap(model): |
| while True: |
| if hasattr(model, "_orig_mod"): |
| model = model._orig_mod |
| elif hasattr(model, "module"): |
| model = model.module |
| else: |
| return model |
|
|
|
|
def load_config(path: str) -> dict:
    """Load a YAML configuration file.

    The file is read as UTF-8 so the result does not depend on the platform's
    default text encoding.  An empty file, for which ``yaml.safe_load`` yields
    ``None``, is returned as an empty dict to honour the return annotation.
    """
    with open(path, encoding="utf-8") as f:
        loaded = yaml.safe_load(f)
    return {} if loaded is None else loaded
|
|
|
|
def build_tokenizer(config: dict):
    """Load the local GPT-2 tokenizer and sync its size into ``config``.

    The tokenizer is read from ``ROOT / "tokenizers" / "gpt2"`` without any
    network access.  Missing EOS / BOS / PAD / MASK tokens are filled in, then
    ``config["model"]["vocab_size"]`` (and the first entry of a non-empty
    ``config["model"]["level_sizes"]``) is overwritten in place with the final
    tokenizer length.

    Returns:
        The configured tokenizer.
    """
    from transformers import AutoTokenizer

    tokenizer_dir = ROOT / "tokenizers" / "gpt2"
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, local_files_only=True)

    if tokenizer.eos_token is None:
        tokenizer.add_special_tokens({"eos_token": "<|endoftext|>"})
    # BOS and PAD fall back to the EOS token when the tokenizer defines none.
    if tokenizer.bos_token is None:
        tokenizer.bos_token = tokenizer.eos_token
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.mask_token_id is None:
        tokenizer.add_special_tokens({"mask_token": "[MASK]"})

    # Measured after all special tokens were added.
    final_vocab = len(tokenizer)
    model_cfg = config["model"]
    model_cfg["vocab_size"] = final_vocab
    level_sizes = model_cfg.get("level_sizes")
    if level_sizes:
        level_sizes[0] = final_vocab
    return tokenizer
|
|
|
|
def build_model(config: dict, device: torch.device) -> SADModel:
    """Instantiate a ``SADModel`` from ``config["model"]`` on ``device``.

    ``vocab_size``, ``hidden_size``, ``n_blocks``, ``n_heads``, ``cond_dim``
    and ``max_seq_len`` are mandatory keys; the remaining constructor
    arguments fall back to the defaults listed below.
    """
    model_cfg = config["model"]

    mandatory_keys = (
        "vocab_size",
        "hidden_size",
        "n_blocks",
        "n_heads",
        "cond_dim",
        "max_seq_len",
    )
    ctor_kwargs = {key: model_cfg[key] for key in mandatory_keys}

    # Optional settings and the value used when the config omits them.
    optional_defaults = (
        ("block_size", 8),
        ("dropout", 0.0),
        ("num_levels", 1),
        ("level_sizes", None),
        ("tie_weights", False),
    )
    for key, fallback in optional_defaults:
        ctor_kwargs[key] = model_cfg.get(key, fallback)

    network = SADModel(**ctor_kwargs)
    return network.to(device)
|
|
|
|
def resolve_dtype(name: str) -> torch.dtype:
    """Translate a short precision name into the matching torch dtype.

    Accepts ``"bf16"``, ``"fp16"`` and ``"fp32"``; any other name raises
    ``KeyError``.
    """
    dtype_by_name = dict(
        bf16=torch.bfloat16,
        fp16=torch.float16,
        fp32=torch.float32,
    )
    return dtype_by_name[name]
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options of the sampling script.

    Args:
        argv: Optional list of argument strings; ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--checkpoint", type=str, required=True)
    p.add_argument("--config", type=str, default="configs/block_diffusion_owt_b32.yaml")
    p.add_argument("--num_samples", type=int, default=1)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument(
        "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
    )
    p.add_argument(
        "--dtype", type=str, default="bf16", choices=["bf16", "fp16", "fp32"]
    )
    # EOS stopping is on by default.  ``--stop_on_eos`` is kept so existing
    # command lines still parse; ``--no_stop_on_eos`` is the only way to
    # actually turn it off (a store_true flag defaulting to True could not).
    p.add_argument("--stop_on_eos", dest="stop_on_eos", action="store_true", default=True)
    p.add_argument("--no_stop_on_eos", dest="stop_on_eos", action="store_false")
    p.add_argument(
        "--mode",
        type=str,
        default="unconditional",
        choices=["unconditional", "conditional"],
    )
    p.add_argument("--prompt_blocks", type=int, default=1)
    p.add_argument("--data_seed", type=int, default=0)
    p.add_argument("--positions_per_step", type=int, default=1)
    p.add_argument("--leaf_temperature", type=float, default=1.0)
    return p.parse_args(argv)
|
|
|
|
def main():
    """Command-line entry point: load a checkpoint, sample and print text.

    Steps: parse arguments and seed torch, build tokenizer and model from the
    YAML config, load the checkpoint weights, optionally take a prompt from
    the training data (``--mode conditional``), run the block sampler and
    print the decoded samples.

    Raises:
        ValueError: In conditional mode, if the requested prompt length is
            negative or not smaller than ``max_seq_len``.
    """
    args = parse_args()
    torch.manual_seed(args.seed)

    device = torch.device(args.device)
    dtype = resolve_dtype(args.dtype)

    config = load_config(args.config)
    # build_tokenizer writes the final vocabulary size into ``config``, so it
    # has to run before the model is constructed.
    tokenizer = build_tokenizer(config)

    model = build_model(config, device).to(dtype)
    # NOTE(security): torch.load unpickles the file unless weights-only
    # loading is in effect - only point this at checkpoints you trust.
    ckpt = torch.load(args.checkpoint, map_location=device)
    # Accept both {"model": state_dict, ...} checkpoints and bare state dicts.
    raw_state = ckpt.get("model", ckpt)
    load_report = _unwrap(model).load_state_dict(raw_state, strict=False)
    model.eval()
    print(f"Loaded checkpoint: {args.checkpoint} (step={ckpt.get('step', '?')})")
    # strict=False would otherwise hide a checkpoint that does not match the
    # architecture; surface the mismatch without aborting.
    if load_report.missing_keys:
        print(
            f"WARNING: {len(load_report.missing_keys)} parameter(s) missing from "
            f"checkpoint: {load_report.missing_keys}",
            file=sys.stderr,
        )
    if load_report.unexpected_keys:
        print(
            f"WARNING: {len(load_report.unexpected_keys)} unexpected key(s) in "
            f"checkpoint: {load_report.unexpected_keys}",
            file=sys.stderr,
        )

    sampler = BlockMaskDiffusionSampler(
        model=_unwrap(model),
        tokenizer=tokenizer,
        device=device,
        dtype=dtype,
        leaf_temperature=args.leaf_temperature,
    )
    print(f"leaf_temperature = {sampler.leaf_temperature}")

    prompt_ids = None
    if args.mode == "conditional":
        # Tolerate a missing or null "data" section.
        data_cfg = config.get("data") or {}
        seq_len = config["model"]["max_seq_len"]
        # Use the block size the model was actually built with; the config key
        # is optional (build_model falls back to a default).
        block_size = sampler.block_size
        prompt_len = args.prompt_blocks * block_size
        if not 0 <= prompt_len < seq_len:
            raise ValueError(
                f"prompt_blocks * block_size = {prompt_len} must be >= 0 and "
                f"< max_seq_len = {seq_len}"
            )
        cache_dir = data_cfg.get("cache_dir", None)
        # A relative cache_dir is resolved against the repository root when
        # that location exists; otherwise it is passed through unchanged.
        if cache_dir is not None and not Path(cache_dir).is_absolute():
            candidate = ROOT / cache_dir
            if candidate.exists():
                cache_dir = str(candidate)
        loader = build_owt_dataloader(
            tokenizer,
            split="train[:-100000]",
            seq_len=seq_len,
            batch_size=args.num_samples,
            num_workers=0,
            cache_dir=cache_dir,
            seed=args.data_seed,
            mode=data_cfg.get("mode", "subsample"),
            shard_across_ranks=False,
        )
        batch = next(iter(loader))
        # The first prompt_len tokens of each sequence serve as the prompt.
        prompt_ids = batch["input_ids"][: args.num_samples, :prompt_len].to(device)
        print(
            "Loaded conditional prompt from training data: "
            f"shape={tuple(prompt_ids.shape)} (prompt_blocks={args.prompt_blocks})"
        )

    out = sampler.generate(
        batch_size=args.num_samples if prompt_ids is None else None,
        prompt_ids=prompt_ids,
        positions_per_step=args.positions_per_step,
        stop_on_eos=args.stop_on_eos,
    )

    prompt_len = out.get("prompt_len", 0)
    print("\n" + "=" * 72)
    for i, ids in enumerate(out["tokens"]):
        ids_list = ids.tolist()
        print(f"[Sample {i + 1}]")
        if prompt_len > 0:
            prompt_text = tokenizer.decode(ids_list[:prompt_len], skip_special_tokens=True)
            gen_text = tokenizer.decode(ids_list[prompt_len:], skip_special_tokens=True)
            print(f"<prompt ({prompt_len} tok)> {prompt_text}")
            print(f"<generated> {gen_text}")
        else:
            print(tokenizer.decode(ids_list, skip_special_tokens=True))
        print()
|
|
|
|
# Run the sampling CLI only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
|
|