sad / scripts /inference_block_diffusion.py
haochengsama's picture
Add files using upload-large-folder tool
8b0aeb2 verified
Raw
History Blame Contribute Delete
18.7 kB
#!/usr/bin/env python3
"""
inference_block_diffusion.py - Block-wise mask-diffusion sampling for SADModel.
This is the block-diffusion counterpart of inference_sad.py:
- no ancestor states / no lambda schedule
- each position is either MASK or LEAF
- within the current block, each round samples leaf tokens for every masked
position, then applies updates to `positions_per_step` random masked
positions per sample
Finalized earlier blocks are cached as K/V so later blocks only recompute the
current block, matching the left-to-right blockwise evaluation setup used by
the block-AR checkpoints.
"""
from __future__ import annotations
import argparse
import sys
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1] # repo root ("sad/"): the parent of this script's directory
from typing import Optional
import torch
import torch.nn.functional as F
import yaml
from einops import rearrange
# Put the repository root on sys.path so the project-local ``src`` package
# imported below resolves when this file is run directly as a script.
# This must stay ahead of the ``src.*`` imports.
sys.path.insert(0, str(ROOT))
from src.data import build_owt_dataloader
from src.models.dit_components import apply_rotary_pos_emb, modulate_fused
from src.models.sad_model import SADModel
class BlockMaskDiffusionSampler:
    """Block-wise mask-diffusion sampler with KV-cache reuse.

    Every position is in one of two states, tracked by ``level_ids``:
    ``0`` means LEAF (a concrete token stored in ``value_ids``) and
    ``mask_level`` (= 1) means MASK. Blocks of ``model.block_size`` positions
    are generated left to right. Once a block is fully unmasked, it is
    re-encoded as a clean block and its per-layer K/V tensors are appended to
    the cache, so later blocks attend to it without recomputation.
    """

    def __init__(
        self,
        model: SADModel,
        tokenizer,
        device: torch.device,
        dtype: torch.dtype = torch.bfloat16,
        leaf_temperature: float = 1.0,
    ):
        """
        Args:
            model: the SADModel to sample from; its ``block_size``,
                ``max_seq_len``, ``vocab_size`` and ``get_leaf_embeddings()``
                are read here.
            tokenizer: must expose ``mask_token_id``; ``eos_token_id`` (if
                present) is used for early stopping in ``generate``.
            device: device on which all sampling tensors are created.
            dtype: dtype for the embedding table and the autocast region.
            leaf_temperature: softmax temperature for token sampling. A value
                of ``0`` selects the argmax token (greedy decoding).
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.dtype = dtype
        self.leaf_temperature = float(leaf_temperature)
        self.block_size: int = model.block_size
        self.max_seq_len: int = model.max_seq_len
        self.vocab_size: int = model.vocab_size
        self.mask_id: int = tokenizer.mask_token_id
        assert self.mask_id is not None, "tokenizer must have mask_token_id"
        self.mask_level = 1
        self.leaf_emb = model.get_leaf_embeddings().to(
            device=device, dtype=dtype
        ).detach()
        # The MASK state is embedded using the mask token's own leaf row.
        self.mask_emb = self.leaf_emb[self.mask_id]

    def _build_mixed_embeddings(
        self, level_ids: torch.Tensor, value_ids: torch.Tensor
    ) -> torch.Tensor:
        """Build ``[B, S, d]`` input embeddings from {leaf, mask} states.

        Leaf positions (``level_ids == 0``) take the embedding of their token
        in ``value_ids``. Mask positions (``level_ids == mask_level``) take
        ``mask_emb``, and their ``value_ids`` entries are ignored.
        """
        B, S = level_ids.shape
        d = self.leaf_emb.shape[-1]
        embs = torch.empty(B, S, d, device=self.device, dtype=self.dtype)
        leaf_mask = level_ids == 0
        if leaf_mask.any():
            embs[leaf_mask] = self.leaf_emb[value_ids[leaf_mask]]
        mask_mask = level_ids == self.mask_level
        if mask_mask.any():
            embs[mask_mask] = self.mask_emb
        return embs

    def _run_layer_cached(
        self,
        layer_idx: int,
        x: torch.Tensor,
        rotary_cos_sin,
        c: torch.Tensor,
        k_prefix: Optional[torch.Tensor] = None,
        v_prefix: Optional[torch.Tensor] = None,
    ):
        """Apply transformer layer ``layer_idx`` to the current block.

        This re-implements the layer's forward pass so that the block's
        queries attend over the cached prefix K/V concatenated with the
        block's own K/V. No attention mask is used, so every query sees the
        whole prefix and the whole current block.

        Returns:
            ``(x_out, k_new, v_new)``. ``k_new`` and ``v_new`` are this block's
            rotary-encoded keys and values in ``[B, H, S, d_head]`` layout,
            ready to be appended to the cache.
        """
        layer = self.model.blocks[layer_idx]
        B = x.shape[0]
        H = layer.n_heads
        dropout = layer.dropout
        # Residual helper supplied by the layer. Presumably it computes
        # skip + gate * out with dropout; confirm in dit_components.
        bds_fn = layer._bias_dropout_scale_fn()
        (
            shift_msa,
            scale_msa,
            gate_msa,
            shift_mlp,
            scale_mlp,
            gate_mlp,
        ) = layer.adaLN_modulation(c)[:, None].chunk(6, dim=2)

        # Attention sub-layer.
        x_skip = x
        x_normed = modulate_fused(layer.norm1(x), shift_msa, scale_msa)
        qkv = layer.attn_qkv(x_normed)
        qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=H)
        cos, sin = rotary_cos_sin
        qkv = apply_rotary_pos_emb(qkv, cos.to(qkv.dtype), sin.to(qkv.dtype))
        q = qkv[:, :, 0].transpose(1, 2)
        k_new = qkv[:, :, 1].transpose(1, 2)
        v_new = qkv[:, :, 2].transpose(1, 2)
        if k_prefix is not None:
            # Concatenate along the sequence axis: [cached prefix | current block].
            k = torch.cat([k_prefix, k_new], dim=2)
            v = torch.cat([v_prefix, v_new], dim=2)
        else:
            k = k_new
            v = v_new
        attn_out = F.scaled_dot_product_attention(q, k, v)
        attn_out = rearrange(attn_out, "b h s d -> b s (h d)", b=B)
        x = bds_fn(layer.attn_out(attn_out), None, gate_msa, x_skip, dropout)

        # MLP sub-layer.
        x = bds_fn(
            layer.mlp(modulate_fused(layer.norm2(x), shift_mlp, scale_mlp)),
            None,
            gate_mlp,
            x,
            dropout,
        )
        return x, k_new, v_new

    def _forward_block_cached(
        self,
        level_ids_cur: torch.Tensor,
        value_ids_cur: torch.Tensor,
        block_idx: int,
        kv_cache: list,
        is_clean: bool = False,
    ):
        """Run the model on one block, attending over the cached earlier blocks.

        Args:
            level_ids_cur: ``[B, bs]`` states of the block (0 = leaf, 1 = mask).
            value_ids_cur: ``[B, bs]`` token ids (only read at leaf positions).
            block_idx: index of this block in the sequence. It sets the
                absolute rotary positions and the block-index embedding.
            kv_cache: one ``(k, v)`` per layer for all earlier blocks, or
                ``(None, None)`` when nothing is cached yet.
            is_clean: ``True`` when the block is fully unmasked context
                (segment id 1), ``False`` while it is being denoised
                (segment id 0).

        Returns:
            ``(logits, new_kv)``. ``logits`` is sliced to ``vocab_size`` with the
            mask token set to ``-inf``, so it can never be sampled as a leaf.
            ``new_kv`` holds this block's own per-layer ``(k, v)``.
        """
        model = self.model
        B, bs = level_ids_cur.shape
        block_start = block_idx * self.block_size
        block_end = block_start + bs
        device = self.device

        embs = self._build_mixed_embeddings(level_ids_cur, value_ids_cur)
        x = model.input_proj(embs)
        block_idx_t = torch.full((bs,), block_idx, dtype=torch.long, device=device)
        # Use the actual block length so every per-position tensor has bs rows.
        intra_pos = torch.arange(bs, device=device)
        seg_id = torch.full(
            (bs,), 1 if is_clean else 0, dtype=torch.long, device=device
        )
        pos_emb = (
            model.block_idx_embed(block_idx_t)
            + model.intra_pos_embed(intra_pos)
            + model.segment_embed(seg_id)
        ).unsqueeze(0).to(x.dtype)
        x = x + pos_emb
        c = model.cond_bias.unsqueeze(0).expand(B, -1).to(x.dtype)
        position_ids = torch.arange(block_start, block_end, device=device)
        rotary_cos_sin = model.rotary_emb(x, position_ids=position_ids)

        new_kv = []
        autocast_device = "cuda" if device.type == "cuda" else "cpu"
        with torch.autocast(device_type=autocast_device, dtype=self.dtype):
            for layer_idx in range(len(model.blocks)):
                k_prefix, v_prefix = kv_cache[layer_idx]
                x, k_cur, v_cur = self._run_layer_cached(
                    layer_idx,
                    x,
                    rotary_cos_sin,
                    c,
                    k_prefix=k_prefix,
                    v_prefix=v_prefix,
                )
                new_kv.append((k_cur, v_cur))
            logits = model.output_layer(x, c)
        logits = logits[..., : self.vocab_size]
        logits[..., self.mask_id] = float("-inf")
        return logits, new_kv

    @staticmethod
    def _append_kv(kv_cache: list, new_kv: list) -> list:
        """Return a new per-layer cache with ``new_kv`` appended on the sequence axis (dim 2)."""
        out = []
        for (kp, vp), (kn, vn) in zip(kv_cache, new_kv):
            if kp is None:
                out.append((kn, vn))
            else:
                out.append((torch.cat([kp, kn], dim=2), torch.cat([vp, vn], dim=2)))
        return out

    def _sample_leaves(self, logits: torch.Tensor) -> torch.Tensor:
        """Draw one token id per position from ``logits`` (last axis = vocab).

        ``self.leaf_temperature`` scales the logits before the softmax. A
        temperature of exactly 0 returns the argmax token instead, because
        dividing by 0 would produce NaN probabilities.

        Returns:
            A long tensor of shape ``logits.shape[:-1]``.
        """
        logits_fp = logits.float()
        if self.leaf_temperature == 0.0:
            return logits_fp.argmax(dim=-1)
        if self.leaf_temperature != 1.0:
            logits_fp = logits_fp / self.leaf_temperature
        probs = F.softmax(logits_fp, dim=-1)
        flat = torch.multinomial(probs.reshape(-1, probs.shape[-1]), num_samples=1)
        return flat.reshape(logits.shape[:-1])

    @torch.no_grad()
    def generate(
        self,
        batch_size: Optional[int] = None,
        prompt_ids: Optional[torch.Tensor] = None,
        positions_per_step: int = 1,
        return_intermediate: bool = False,
        stop_on_eos: bool = True,
    ) -> dict:
        """Generate ``max_seq_len`` tokens block by block.

        Args:
            batch_size: number of unconditional samples. It is ignored when
                ``prompt_ids`` is given.
            prompt_ids: ``[B, P]`` (or ``[P]``) prompt token ids. ``P`` must be
                a multiple of ``block_size`` and ``< max_seq_len``. Prompt
                blocks are encoded once as clean context and cached.
            positions_per_step: the maximum number of masked positions unmasked
                per sample in each round. Must be ``>= 1``.
            return_intermediate: record CPU copies of ``(level_ids, value_ids)``
                after every round.
            stop_on_eos: once every sample has produced ``eos_token_id`` in
                some generated block, stop before generating further blocks.

        Returns:
            A dict with these keys:

            - ``"tokens"``: a ``[B, max_seq_len]`` CPU long tensor. Positions
              never generated (because of an EOS early stop) hold
              ``mask_id``.
            - ``"prompt_len"``: the prompt length ``P``.
            - ``"num_steps"``: the number of denoising rounds run.
            - ``"intermediate"``: the recorded snapshots, present only if
              requested.

        Raises:
            ValueError: if ``positions_per_step < 1``.
        """
        if positions_per_step < 1:
            # With 0 positions per round no progress is made, and everything
            # would silently be decided by the one-shot fill pass below.
            raise ValueError(
                f"positions_per_step must be >= 1, got {positions_per_step}"
            )
        block_size = self.block_size
        device = self.device
        total_len = self.max_seq_len
        assert total_len % block_size == 0, (
            f"max_seq_len ({total_len}) must be divisible by block_size ({block_size})"
        )
        if prompt_ids is not None:
            prompt_ids = prompt_ids.to(device=device, dtype=torch.long)
            if prompt_ids.dim() == 1:
                prompt_ids = prompt_ids.unsqueeze(0)
            B, P = prompt_ids.shape
            assert P % block_size == 0, (
                f"prompt length P={P} must be a multiple of block_size={block_size}"
            )
            assert P < total_len, (
                f"prompt length P={P} must be < total_len={total_len}"
            )
            start_block = P // block_size
        else:
            assert batch_size is not None, "batch_size or prompt_ids must be provided"
            B = batch_size
            P = 0
            start_block = 0

        level_ids = torch.full(
            (B, total_len), self.mask_level, dtype=torch.long, device=device
        )
        # Masked positions' values are never read by the model. Filling them
        # with mask_id (rather than 0, a real token) keeps any position left
        # ungenerated by an early stop distinguishable from sampled tokens.
        value_ids = torch.full(
            (B, total_len), self.mask_id, dtype=torch.long, device=device
        )
        if P > 0:
            level_ids[:, :P] = 0
            value_ids[:, :P] = prompt_ids

        num_blocks = total_len // block_size
        intermediate = [] if return_intermediate else None
        finished = torch.zeros(B, dtype=torch.bool, device=device)
        eos_id = getattr(self.tokenizer, "eos_token_id", None)

        # Encode the prompt blocks once as clean context.
        kv_cache = [(None, None) for _ in range(len(self.model.blocks))]
        for b in range(start_block):
            span = slice(b * block_size, (b + 1) * block_size)
            _, new_kv = self._forward_block_cached(
                level_ids[:, span], value_ids[:, span], b, kv_cache, is_clean=True
            )
            kv_cache = self._append_kv(kv_cache, new_kv)

        total_steps = 0
        k = min(positions_per_step, block_size)
        for b in range(start_block, num_blocks):
            span = slice(b * block_size, (b + 1) * block_size)
            # Each round unmasks at least one position in every sample that
            # still has masks (k >= 1), so block_size rounds always suffice.
            for _ in range(block_size):
                masked = level_ids[:, span] > 0
                if not masked.any():
                    break
                block_logits, _ = self._forward_block_cached(
                    level_ids[:, span], value_ids[:, span], b, kv_cache
                )
                leaf_id = self._sample_leaves(block_logits)
                # Random priorities over masked positions. Already-unmasked
                # positions get -1, so topk picks them only when a sample has
                # fewer than k masks left, and `& masked` then drops them.
                scores = torch.rand(B, block_size, device=device)
                scores = scores.masked_fill(~masked, -1.0)
                topk_idx = scores.topk(k, dim=-1).indices
                selected = torch.zeros_like(masked)
                selected.scatter_(1, topk_idx, True)
                apply_mask = selected & masked
                level_ids[:, span] = level_ids[:, span].masked_fill(apply_mask, 0)
                value_ids[:, span] = torch.where(
                    apply_mask, leaf_id, value_ids[:, span]
                )
                if return_intermediate:
                    intermediate.append(
                        (level_ids.clone().cpu(), value_ids.clone().cpu())
                    )
                total_steps += 1

            # Safety net: sample any still-masked positions in one pass.
            masked = level_ids[:, span] > 0
            if masked.any():
                block_logits, _ = self._forward_block_cached(
                    level_ids[:, span], value_ids[:, span], b, kv_cache
                )
                leaf_id = self._sample_leaves(block_logits)
                level_ids[:, span] = level_ids[:, span].masked_fill(masked, 0)
                value_ids[:, span] = torch.where(masked, leaf_id, value_ids[:, span])

            # Re-encode the finished block as clean context and cache its K/V.
            _, new_kv = self._forward_block_cached(
                level_ids[:, span], value_ids[:, span], b, kv_cache, is_clean=True
            )
            kv_cache = self._append_kv(kv_cache, new_kv)

            if stop_on_eos and eos_id is not None:
                finished |= value_ids[:, span].eq(eos_id).any(dim=-1)
                if finished.all():
                    break

        result = {
            "tokens": value_ids.cpu(),
            "prompt_len": P,
            "num_steps": total_steps,
        }
        if return_intermediate:
            result["intermediate"] = intermediate
        return result
def _unwrap(model):
    """Peel wrapper objects off ``model`` and return the innermost module.

    Checks the wrapper attribute ``_orig_mod`` first, then ``module``, and keeps
    peeling until neither attribute is present. ``_orig_mod`` is the attribute
    set by ``torch.compile``; ``module`` is set by DataParallel and DDP.
    """
    current = model
    while True:
        for attr in ("_orig_mod", "module"):
            if hasattr(current, attr):
                current = getattr(current, attr)
                break
        else:
            return current
def load_config(path: str) -> dict:
    """Load a YAML config file with ``yaml.safe_load``.

    A relative ``path`` that does not exist under the current working
    directory is retried relative to the repository root ``ROOT``. This
    mirrors how ``main`` resolves ``cache_dir``, so the default
    ``configs/...`` path also works when the script is launched from
    elsewhere. The file is decoded as UTF-8 explicitly, so parsing does not
    depend on the platform's locale encoding.
    """
    config_path = Path(path)
    if not config_path.is_absolute() and not config_path.exists():
        candidate = ROOT / config_path
        if candidate.exists():
            config_path = candidate
    with open(config_path, encoding="utf-8") as f:
        return yaml.safe_load(f)
def build_tokenizer(config: dict):
    """Load the local GPT-2 tokenizer and sync vocabulary sizes into ``config``.

    The tokenizer is loaded from ``ROOT/tokenizers/gpt2`` without network
    access. Missing tokens are filled in as follows:

    - eos: ``<|endoftext|>`` is added as a special token.
    - bos and pad: set to the eos token.
    - mask: ``[MASK]`` is added as a special token.

    The final ``len(tokenizer)`` is then written into
    ``config["model"]["vocab_size"]``. It is also written into
    ``config["model"]["level_sizes"][0]`` when that list exists and is
    non-empty. ``config`` is mutated in place.
    """
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(
        ROOT / "tokenizers" / "gpt2",
        local_files_only=True,
    )
    if tokenizer.eos_token is None:
        tokenizer.add_special_tokens({"eos_token": "<|endoftext|>"})
    for attr in ("bos_token", "pad_token"):
        if getattr(tokenizer, attr) is None:
            setattr(tokenizer, attr, tokenizer.eos_token)
    if tokenizer.mask_token_id is None:
        tokenizer.add_special_tokens({"mask_token": "[MASK]"})

    vocab = len(tokenizer)
    model_cfg = config["model"]
    model_cfg["vocab_size"] = vocab
    level_sizes = model_cfg.get("level_sizes")
    if level_sizes:
        level_sizes[0] = vocab
    return tokenizer
def build_model(config: dict, device: torch.device) -> SADModel:
    """Instantiate ``SADModel`` from ``config["model"]`` and move it to ``device``.

    Required keys: vocab_size, hidden_size, n_blocks, n_heads, cond_dim and
    max_seq_len. Optional keys, with the defaults used here:

    - block_size: 8
    - dropout: 0.0
    - num_levels: 1
    - level_sizes: None
    - tie_weights: False
    """
    model_cfg = config["model"]
    required = (
        "vocab_size",
        "hidden_size",
        "n_blocks",
        "n_heads",
        "cond_dim",
        "max_seq_len",
    )
    kwargs = {key: model_cfg[key] for key in required}
    kwargs.update(
        block_size=model_cfg.get("block_size", 8),
        dropout=model_cfg.get("dropout", 0.0),
        num_levels=model_cfg.get("num_levels", 1),
        level_sizes=model_cfg.get("level_sizes"),
        tie_weights=model_cfg.get("tie_weights", False),
    )
    model = SADModel(**kwargs)
    return model.to(device)
def resolve_dtype(name: str) -> torch.dtype:
    """Translate a CLI precision name into a ``torch.dtype``.

    Accepts ``"bf16"``, ``"fp16"`` or ``"fp32"``. Any other name raises
    ``KeyError``.
    """
    if name == "bf16":
        return torch.bfloat16
    if name == "fp16":
        return torch.float16
    if name == "fp32":
        return torch.float32
    raise KeyError(name)
def parse_args():
    """Parse the command-line options for block-diffusion sampling."""
    p = argparse.ArgumentParser(
        description="Block-wise mask-diffusion sampling for SADModel."
    )
    p.add_argument("--checkpoint", type=str, required=True,
                   help="path to a checkpoint saved with torch.save")
    p.add_argument("--config", type=str, default="configs/block_diffusion_owt_b32.yaml",
                   help="YAML model/data config")
    p.add_argument("--num_samples", type=int, default=1)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument(
        "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
    )
    p.add_argument(
        "--dtype", type=str, default="bf16", choices=["bf16", "fp16", "fp32"]
    )
    # `store_true` with default=True can never be switched off on its own.
    # The paired negative flag writes to the same destination.
    p.add_argument("--stop_on_eos", dest="stop_on_eos", action="store_true", default=True,
                   help="stop once every sample has emitted EOS (default)")
    p.add_argument("--no_stop_on_eos", dest="stop_on_eos", action="store_false",
                   help="always generate up to max_seq_len")
    p.add_argument(
        "--mode",
        type=str,
        default="unconditional",
        choices=["unconditional", "conditional"],
    )
    p.add_argument("--prompt_blocks", type=int, default=1,
                   help="number of data blocks used as the prompt in conditional mode")
    p.add_argument("--data_seed", type=int, default=0)
    p.add_argument("--positions_per_step", type=int, default=1,
                   help="masked positions unmasked per sample per round (>= 1)")
    p.add_argument("--leaf_temperature", type=float, default=1.0,
                   help="token sampling temperature")
    return p.parse_args()
def _load_weights(model, raw_state: dict) -> None:
    """Load ``raw_state`` into ``model`` non-strictly, tolerating wrapper key prefixes.

    A key that does not match the model directly is retried with leading
    ``_orig_mod.`` / ``module.`` prefixes removed. These are the prefixes left
    when the state_dict of a torch.compile'd or DDP-wrapped model is saved.
    Missing and unexpected keys are reported instead of silently ignored,
    since ``strict=False`` alone would otherwise hide a checkpoint that failed
    to load.
    """
    expected = set(model.state_dict().keys())
    wrapper_prefixes = ("_orig_mod.", "module.")
    remapped = {}
    for key, value in raw_state.items():
        name = key
        while name not in expected:
            for prefix in wrapper_prefixes:
                if name.startswith(prefix):
                    name = name[len(prefix):]
                    break
            else:
                # Nothing left to strip: keep the original key so that it is
                # reported as unexpected below.
                name = key
                break
        remapped[name] = value
    incompatible = model.load_state_dict(remapped, strict=False)
    if incompatible.missing_keys:
        print(
            f"WARNING: {len(incompatible.missing_keys)} model key(s) missing from "
            f"checkpoint, e.g. {incompatible.missing_keys[:5]}"
        )
    if incompatible.unexpected_keys:
        print(
            f"WARNING: {len(incompatible.unexpected_keys)} unexpected checkpoint "
            f"key(s) ignored, e.g. {incompatible.unexpected_keys[:5]}"
        )


def main():
    """Command-line entry point: load config, tokenizer and checkpoint, sample, and print text."""
    args = parse_args()
    torch.manual_seed(args.seed)
    device = torch.device(args.device)
    dtype = resolve_dtype(args.dtype)
    config = load_config(args.config)
    # build_tokenizer also writes the tokenizer's size into config["model"].
    tokenizer = build_tokenizer(config)
    model = build_model(config, device).to(dtype)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(args.checkpoint, map_location=device)
    raw_state = ckpt.get("model", ckpt)
    base_model = _unwrap(model)
    _load_weights(base_model, raw_state)
    model.eval()
    print(f"Loaded checkpoint: {args.checkpoint} (step={ckpt.get('step', '?')})")
    sampler = BlockMaskDiffusionSampler(
        model=base_model,
        tokenizer=tokenizer,
        device=device,
        dtype=dtype,
        leaf_temperature=args.leaf_temperature,
    )
    print(f"leaf_temperature = {sampler.leaf_temperature}")

    prompt_ids = None
    if args.mode == "conditional":
        data_cfg = config.get("data", {})
        # Use the model's own geometry, which is exactly what generate()
        # validates against. config["model"] may omit block_size, because
        # build_model falls back to a default for it.
        seq_len = sampler.max_seq_len
        block_size = sampler.block_size
        prompt_len = args.prompt_blocks * block_size
        assert prompt_len < seq_len, (
            f"prompt_blocks * block_size = {prompt_len} must be < max_seq_len = {seq_len}"
        )
        cache_dir = data_cfg.get("cache_dir", None)
        if cache_dir is not None and not Path(cache_dir).is_absolute():
            # Prefer a cache directory relative to the repository root if present.
            candidate = ROOT / cache_dir
            if candidate.exists():
                cache_dir = str(candidate)
        loader = build_owt_dataloader(
            tokenizer,
            split="train[:-100000]",
            seq_len=seq_len,
            batch_size=args.num_samples,
            num_workers=0,
            cache_dir=cache_dir,
            seed=args.data_seed,
            mode=data_cfg.get("mode", "subsample"),
            shard_across_ranks=False,
        )
        batch = next(iter(loader))
        prompt_ids = batch["input_ids"][: args.num_samples, :prompt_len].to(device)
        print(
            "Loaded conditional prompt from training data: "
            f"shape={tuple(prompt_ids.shape)} (prompt_blocks={args.prompt_blocks})"
        )

    out = sampler.generate(
        batch_size=args.num_samples if prompt_ids is None else None,
        prompt_ids=prompt_ids,
        positions_per_step=args.positions_per_step,
        stop_on_eos=args.stop_on_eos,
    )
    prompt_len = out.get("prompt_len", 0)
    print("\n" + "=" * 72)
    for i, ids in enumerate(out["tokens"]):
        ids_list = ids.tolist()
        print(f"[Sample {i + 1}]")
        if prompt_len > 0:
            prompt_text = tokenizer.decode(ids_list[:prompt_len], skip_special_tokens=True)
            gen_text = tokenizer.decode(ids_list[prompt_len:], skip_special_tokens=True)
            print(f"<prompt ({prompt_len} tok)> {prompt_text}")
            print(f"<generated> {gen_text}")
        else:
            print(tokenizer.decode(ids_list, skip_special_tokens=True))
        print()
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()