| |
| """ |
| eval_gen_ppl.py — Generative perplexity of SAD samples under a pretrained LM. |
| |
| Mirrors the standard "gen_ppl" pipeline used by HDLM/MDLM/soft-mask: |
| |
| 1. Draw N unconditional samples from a trained SAD checkpoint |
| (length = model.max_seq_len). |
| 2. Decode them into text with the SAD tokenizer. |
| 3. Feed the decoded text through a pretrained AR eval LM (default: local |
| models/gpt2-large), compute standard next-token cross-entropy. |
| 4. Report avg_nll, median_nll, ppl = exp(total_nll / total_tokens), acc. |
| |
| The metric measures how "natural" SAD samples look under the eval LM — it is |
| NOT a model-intrinsic PPL (no ELBO, no test set). It is directly comparable |
| to soft-mask's `val/gen_ppl` and HDLM's `eval/generative_ppl.py`. |
| |
| Usage: |
| python scripts/eval_gen_ppl.py \\ |
| --checkpoint outputs/sad/latest.pt \\ |
| --config configs/sad_owt.yaml \\ |
| --num_samples 256 \\ |
| --sample_batch_size 16 \\ |
| --eval_model_path models/gpt2 |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import copy |
| import json |
| import sys |
| from pathlib import Path |
|
|
| ROOT = Path(__file__).resolve().parents[1] |
| from typing import Any |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
|
|
| sys.path.insert(0, str(ROOT)) |
| sys.path.insert(0, str(Path(__file__).parent)) |
|
|
| from inference_sad import ( |
| BlockDiffusionSampler, |
| build_ancestor_table, |
| build_model, |
| build_tokenizer, |
| load_config, |
| resolve_dtype, |
| _unwrap, |
| ) |
|
|
|
|
| |
| |
| |
| |
# Optional hard-coded text to score. Leave empty to disable. main() strips
# this value and, when the result is non-empty, may score it directly under
# the eval LM instead of running SAD sampling.
INPUT_TEXT = ""
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options of the generative-perplexity script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behavior.

    Returns:
        The populated ``argparse.Namespace``. ``eval_tokenizer_path`` is
        always set: when the flag is omitted it is filled in with the value
        of ``eval_model_path`` so the tokenizer follows the chosen eval LM.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--model_type", type=str, default="sad",
                   choices=["sad", "block_diffusion"],
                   help="Generation backend. 'sad' expects an ancestor-table "
                        "checkpoint; 'block_diffusion' expects the mask-only checkpoint.")
    p.add_argument("--checkpoint", type=str, default=None,
                   help="SAD checkpoint. Required unless --input_text or "
                        "--input_file is given (text-only scoring mode).")
    p.add_argument("--config", type=str, default=None,
                   help="Optional config path. If omitted, uses the config "
                        "stored inside --checkpoint.")
    p.add_argument("--input_text", type=str, default=None,
                   help="Score this single string under the eval LM instead "
                        "of running SAD sampling. Skips SAD model loading.")
    p.add_argument("--input_file", type=str, default=None,
                   help="Path to a text file, one sentence per line; each "
                        "non-empty line is scored as a separate sample. "
                        "Mutually exclusive with --input_text.")
    p.add_argument("--num_samples", type=int, default=256,
                   help="Total unconditional samples to generate.")
    p.add_argument("--sample_batch_size", type=int, default=16,
                   help="Batch size for SAD sampling.")
    p.add_argument("--eval_batch_size", type=int, default=8,
                   help="Batch size when feeding samples to the eval LM.")
    p.add_argument("--eval_model_path", type=str, default="models/gpt2-large",
                   help="Path (relative to sad/ or absolute) to a local "
                        "HF causal-LM checkpoint used as the PPL evaluator. "
                        "Default expects `huggingface-cli download gpt2-large "
                        "--local-dir models/gpt2-large` to have been run.")
    # The default is None (resolved below) so that overriding only
    # --eval_model_path does not leave the tokenizer pointing at gpt2-large.
    p.add_argument("--eval_tokenizer_path", type=str, default=None,
                   help="Path to the eval-LM's tokenizer. For HF-downloaded "
                        "models, tokenizer files sit alongside weights, so "
                        "this defaults to the same path as --eval_model_path.")
    p.add_argument("--eval_max_length", type=int, default=1024,
                   help="Truncation length for eval-LM tokenization.")
    p.add_argument("--device", type=str,
                   default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument("--dtype", type=str, default="bf16",
                   choices=["bf16", "fp16", "fp32"],
                   help="dtype for SAD sampling (eval LM always runs fp32).")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--output", type=str, default="outputs/gen_ppl_metrics.json")
    p.add_argument("--save_samples", type=str, default=None,
                   help="Optional path to dump decoded text samples (JSON).")
    p.add_argument("--level_lambdas", type=str, default=None,
                   help="Comma-separated K floats in [0, 1], one per ancestor "
                        "level l = 1..K (e.g. '1.0,0.8,0.5'). Multiplies the "
                        "level's max-prob conf before the cross-level argmax. "
                        "Default: all 1.0 (original behavior).")
    p.add_argument("--positions_per_step", type=int, default=1,
                   help="Number of random non-leaf positions to advance per "
                        "denoising round within a block. Larger -> fewer "
                        "denoising rounds but less sequential refinement.")
    p.add_argument("--leaf_temperature", type=float, default=1.0,
                   help="Temperature applied to leaf logits before softmax. "
                        "Values < 1.0 sharpen p_leaf, which is then used for "
                        "both leaf multinomial sampling and ancestor projection. "
                        "Default 1.0 (no sharpening).")

    args = p.parse_args(argv)
    if args.eval_tokenizer_path is None:
        args.eval_tokenizer_path = args.eval_model_path
    return args
|
|
|
|
| |
| |
| |
|
|
@torch.no_grad()
def sample_many(sampler: Any,
                num_samples: int, batch_size: int,
                positions_per_step: int = 1):
    """Generate `num_samples` unconditional sequences in chunks.

    Args:
        sampler: Object exposing ``generate(batch_size=..., positions_per_step=...)``
            that returns a mapping with a ``"tokens"`` tensor and a
            ``"num_steps"`` count.
        num_samples: Total number of sequences to draw; must be >= 1.
        batch_size: Maximum number of sequences per ``generate`` call;
            must be >= 1.
        positions_per_step: Forwarded unchanged to ``sampler.generate``.

    Returns:
        ``(tokens, avg_steps)``: the per-call ``"tokens"`` tensors
        concatenated along dim 0, and the mean of ``"num_steps"`` weighted by
        the size of each call's batch (one ``generate`` call reports a single
        step count for its whole batch).

    Raises:
        ValueError: if ``num_samples`` or ``batch_size`` is smaller than 1.
            A non-positive batch size would otherwise never advance the loop,
            and zero samples would divide by zero when averaging.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    chunks = []
    total_steps_weighted = 0
    done = 0
    while done < num_samples:
        # The final chunk may be smaller than batch_size.
        bs = min(batch_size, num_samples - done)
        out = sampler.generate(
            batch_size=bs, positions_per_step=positions_per_step,
        )
        chunks.append(out["tokens"])
        total_steps_weighted += out["num_steps"] * bs
        done += bs
        print(f" sampled {done}/{num_samples} (steps this call: {out['num_steps']})")
    avg_steps = total_steps_weighted / done
    return torch.cat(chunks, dim=0), avg_steps
|
|
|
|
| |
| |
| |
|
|
@torch.no_grad()
def score_with_eval_lm(
    texts: list[str],
    eval_model,
    eval_tokenizer,
    device: torch.device,
    batch_size: int,
    max_length: int,
) -> dict:
    """Standard next-token CE under a pretrained AR eval LM.

    Args:
        texts: Strings to score.
        eval_model: Callable accepting ``input_ids``, ``attention_mask``,
            ``use_cache`` and ``return_dict`` keywords and returning an object
            with a ``logits`` attribute (indexed here as batch x position x
            vocabulary).
        eval_tokenizer: Callable tokenizer; its result must support
            ``.to(device)`` and provide ``"input_ids"`` / ``"attention_mask"``.
        device: Device the tokenized batch is moved to.
        batch_size: Number of texts per forward pass; must be >= 1.
        max_length: Truncation length passed to the tokenizer.

    Returns:
        Dict with ``avg_nll`` (total NLL / scored tokens), ``median_nll``
        (median over all scored tokens), ``ppl`` (``exp(avg_nll)``), ``acc``
        (fraction of scored tokens whose argmax prediction is correct) and
        ``tokens`` (number of scored tokens).

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
        RuntimeError: if no token could be scored at all.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    total_nll = 0.0
    total_tokens = 0
    total_acc = 0.0
    all_nlls = []

    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        enc = eval_tokenizer(
            batch, padding=True, return_tensors="pt",
            truncation=True, max_length=max_length,
        ).to(device)
        input_ids = enc["input_ids"]
        attn_mask = enc["attention_mask"]

        # A batch whose padded length is below 2 has no (context, target)
        # pair to score; skip the forward pass instead of feeding the model
        # an empty or single-token sequence.
        if input_ids.shape[1] >= 2:
            outputs = eval_model(
                input_ids=input_ids, attention_mask=attn_mask,
                use_cache=False, return_dict=True,
            )
            # Position t predicts token t + 1: drop the last logit and the
            # first label, and shift the attention mask the same way.
            logits = outputs.logits[:, :-1]
            labels = input_ids[:, 1:]
            loss_mask = attn_mask[:, 1:]

            # cross_entropy expects the class dimension second.
            nll = F.cross_entropy(
                logits.transpose(-1, -2), labels, reduction="none",
            )

            # Only targets marked by the attention mask count.
            valid = loss_mask.bool()
            nll_valid = nll[valid]
            total_nll += nll_valid.sum().item()
            total_tokens += int(valid.sum().item())
            all_nlls.extend(nll_valid.detach().cpu().tolist())

            preds = logits.argmax(dim=-1)
            total_acc += ((preds == labels).float() * loss_mask).sum().item()

        print(f" scored {min(i + batch_size, len(texts))}/{len(texts)}")

    if total_tokens == 0:
        raise RuntimeError("No valid tokens scored - all samples were empty?")

    avg_nll = total_nll / total_tokens
    return {
        "avg_nll": avg_nll,
        "median_nll": float(np.median(all_nlls)),
        "ppl": float(np.exp(avg_nll)),
        "acc": total_acc / total_tokens,
        "tokens": total_tokens,
    }
|
|
|
|
| |
| |
| |
|
|
def _resolve_under_root(path_str):
    """Return `path_str` as a Path, anchored at ROOT when it is relative."""
    path = Path(path_str)
    return path if path.is_absolute() else ROOT / path


def _report_skipped_keys(load_result, label):
    """Print the keys a non-strict load_state_dict call left unmatched."""
    # getattr keeps this harmless if the loader returns something without
    # the usual missing_keys / unexpected_keys fields.
    missing = list(getattr(load_result, "missing_keys", None) or [])
    unexpected = list(getattr(load_result, "unexpected_keys", None) or [])
    if missing:
        print(f"WARNING: {label}: {len(missing)} missing key(s) in checkpoint, "
              f"e.g. {missing[:5]}")
    if unexpected:
        print(f"WARNING: {label}: {len(unexpected)} unexpected key(s) in checkpoint, "
              f"e.g. {unexpected[:5]}")


def _read_input_texts(input_file, input_text):
    """Return the texts to score in text-only mode (one string, or file lines)."""
    if input_text is not None:
        return [input_text]
    with open(input_file, encoding="utf-8") as f:
        texts = [ln.rstrip("\n") for ln in f if ln.strip()]
    if not texts:
        raise ValueError(f"No non-empty lines found in {input_file}")
    return texts


def _build_sad_sampler(args, ckpt, config, device, dtype):
    """Build the SAD model, its ancestor table and a BlockDiffusionSampler."""
    if "ancestor_table" not in ckpt:
        raise ValueError("Checkpoint has no 'ancestor_table' entry.")

    tokenizer = build_tokenizer(config)
    model = build_model(config, device).to(dtype)
    # Checkpoints either wrap the weights under "model" or are the raw state.
    load_result = _unwrap(model).load_state_dict(
        ckpt.get("model", ckpt), strict=False,
    )
    _report_skipped_keys(load_result, "SAD model")
    model.eval()
    print(f"Loaded SAD checkpoint: {args.checkpoint} "
          f"(step={ckpt.get('step', '?')})")

    ancestor_table = build_ancestor_table(
        config, device, embed_dim=config["model"]["hidden_size"],
    )
    ancestor_table.load_state_dict(ckpt["ancestor_table"])
    ancestor_table.to(device=device, dtype=dtype).eval()

    level_lambdas = None
    if args.level_lambdas:
        level_lambdas = [float(x) for x in args.level_lambdas.split(",")]

    sampler = BlockDiffusionSampler(
        model=_unwrap(model), ancestor_table=ancestor_table,
        tokenizer=tokenizer, device=device, dtype=dtype,
        level_lambdas=level_lambdas,
        leaf_temperature=args.leaf_temperature,
    )
    print(f"level_lambdas = {sampler.level_lambdas[1:]}")
    print(f"leaf_temperature = {sampler.leaf_temperature}")
    return sampler, tokenizer


def _build_block_mask_sampler(args, ckpt, config, device, dtype):
    """Build the mask-only block-diffusion model and its sampler."""
    from inference_block_diffusion import (
        BlockMaskDiffusionSampler,
        build_model as build_mask_model,
        build_tokenizer as build_mask_tokenizer,
        _unwrap as unwrap_mask,
    )

    tokenizer = build_mask_tokenizer(config)
    model = build_mask_model(config, device).to(dtype)
    load_result = unwrap_mask(model).load_state_dict(
        ckpt.get("model", ckpt), strict=False,
    )
    _report_skipped_keys(load_result, "block-mask model")
    model.eval()
    print(f"Loaded block-mask checkpoint: {args.checkpoint} "
          f"(step={ckpt.get('step', '?')})")

    sampler = BlockMaskDiffusionSampler(
        model=unwrap_mask(model),
        tokenizer=tokenizer,
        device=device,
        dtype=dtype,
        leaf_temperature=args.leaf_temperature,
    )
    print(f"leaf_temperature = {sampler.leaf_temperature}")
    return sampler, tokenizer


def _generate_texts(args, device, dtype):
    """Sample from the checkpoint and decode to text.

    Returns (texts, generated_seq_len, avg_steps). Everything model-related
    is local to this function, so it becomes collectable once it returns.
    """
    if args.checkpoint is None:
        raise ValueError(
            "--checkpoint is required unless --input_text/--input_file is set."
        )
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    ckpt = torch.load(args.checkpoint, map_location=device)
    if args.config is not None:
        config = load_config(args.config)
        config_source = f"cli:{args.config}"
    else:
        if "config" not in ckpt:
            raise ValueError(
                "--config was not provided and checkpoint has no embedded "
                "'config' entry."
            )
        config = copy.deepcopy(ckpt["config"])
        config_source = f"checkpoint:{args.checkpoint}"
    print(f"Using config from {config_source}")

    if args.model_type == "sad":
        sampler, tokenizer = _build_sad_sampler(args, ckpt, config, device, dtype)
    else:
        sampler, tokenizer = _build_block_mask_sampler(
            args, ckpt, config, device, dtype,
        )
    # The weights have been copied into the model; drop the checkpoint so it
    # does not stay resident during sampling and eval-LM scoring.
    del ckpt

    L = config["model"]["max_seq_len"]
    print(f"Generating {args.num_samples} samples (L={L})...")
    tokens, avg_steps = sample_many(
        sampler, args.num_samples, args.sample_batch_size,
        positions_per_step=args.positions_per_step,
    )
    print(f"Average denoising rounds per sample: {avg_steps:.2f}")
    texts = tokenizer.batch_decode(
        tokens.tolist(), skip_special_tokens=True,
    )
    print(f"First sample preview: {texts[0][:120]!r}")
    return texts, int(tokens.shape[1]), avg_steps


def main():
    """Run the generative-perplexity evaluation end to end.

    Either scores user-supplied text (``--input_text`` / ``--input_file`` /
    the module-level ``INPUT_TEXT``) or draws samples from the checkpoint,
    then scores the texts with the eval LM, prints the metrics and writes
    them (and optionally the texts) as JSON. Relative model/output paths are
    resolved against ``ROOT``.

    Raises:
        ValueError: on conflicting or missing CLI inputs, a checkpoint that
            lacks a required entry, or an input file with no non-empty lines.
    """
    args = parse_args()
    torch.manual_seed(args.seed)

    device = torch.device(args.device)
    dtype = resolve_dtype(args.dtype)

    if args.input_text and args.input_file:
        raise ValueError("--input_text and --input_file are mutually exclusive.")

    # Explicit CLI input wins; the hard-coded INPUT_TEXT is only a fallback
    # when neither --input_text nor --input_file was given.
    hardcoded_text = INPUT_TEXT.strip() or None
    effective_input_text = args.input_text or (
        None if args.input_file else hardcoded_text
    )
    text_mode = bool(effective_input_text or args.input_file)

    if text_mode:
        texts = _read_input_texts(args.input_file, effective_input_text)
        print(f"Scoring {len(texts)} input text(s) directly under the eval LM "
              f"(SAD sampling skipped).")
        generated_seq_len = None
        avg_steps = None
    else:
        texts, generated_seq_len, avg_steps = _generate_texts(args, device, dtype)
        # The generation model is out of scope now; release cached GPU
        # memory before the eval LM is loaded.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    from transformers import AutoModelForCausalLM, AutoTokenizer

    eval_model_path = _resolve_under_root(args.eval_model_path)
    eval_tok_path = _resolve_under_root(args.eval_tokenizer_path)
    print(f"Loading eval LM: {eval_model_path}")
    print(f"Loading eval tokenizer: {eval_tok_path}")

    eval_tokenizer = AutoTokenizer.from_pretrained(
        str(eval_tok_path), local_files_only=True,
    )
    # Batched tokenization with padding=True needs a pad token; fall back to
    # EOS when the tokenizer defines none.
    if eval_tokenizer.pad_token is None:
        eval_tokenizer.pad_token = eval_tokenizer.eos_token

    eval_model = AutoModelForCausalLM.from_pretrained(
        str(eval_model_path), local_files_only=True,
        torch_dtype=torch.float32,
    ).to(device).eval()
    print(f"Eval LM loaded ({sum(p.numel() for p in eval_model.parameters()):,} params)")

    print("Scoring samples under eval LM...")
    metrics = score_with_eval_lm(
        texts, eval_model, eval_tokenizer, device,
        args.eval_batch_size, args.eval_max_length,
    )
    metrics.update({
        "checkpoint": args.checkpoint,
        "eval_model": str(eval_model_path),
        "eval_tokenizer": str(eval_tok_path),
        "num_samples": len(texts),
        "generated_seq_len": generated_seq_len,
        "mode": "text_input" if text_mode else (
            "block_diffusion_generation" if args.model_type == "block_diffusion" else "sad_generation"
        ),
        "model_type": args.model_type,
        "level_lambdas": None if args.model_type == "block_diffusion" else args.level_lambdas,
        "avg_steps": avg_steps,
        "positions_per_step": args.positions_per_step,
        "leaf_temperature": args.leaf_temperature,
    })
    print(json.dumps(metrics, indent=2))

    out_path = _resolve_under_root(args.output)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)
    print(f"Saved metrics -> {out_path}")

    if args.save_samples:
        s_path = _resolve_under_root(args.save_samples)
        s_path.parent.mkdir(parents=True, exist_ok=True)
        with open(s_path, "w", encoding="utf-8") as f:
            json.dump({"samples": texts}, f, indent=2)
        print(f"Saved samples -> {s_path}")
|
|
|
|
# Run the evaluation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
|
|