#!/usr/bin/env python3 """ eval_gen_ppl.py – Generative perplexity of SAD samples under a pretrained LM. Mirrors the standard "gen_ppl" pipeline used by HDLM/MDLM/soft-mask: 1. Draw N unconditional samples from a trained SAD checkpoint (length = model.max_seq_len). 2. Decode them into text with the SAD tokenizer. 3. Feed the decoded text through a pretrained AR eval LM (default: local gpt2), compute standard next-token cross-entropy. 4. Report avg_nll, median_nll, ppl = exp(total_nll / total_tokens), acc. The metric measures how "natural" SAD samples look under the eval LM — it is NOT a model-intrinsic PPL (no ELBO, no test set). It is directly comparable to soft-mask's `val/gen_ppl` and HDLM's `eval/generative_ppl.py`. Usage: python scripts/eval_gen_ppl.py \\ --checkpoint outputs/sad/latest.pt \\ --config configs/sad_owt.yaml \\ --num_samples 256 \\ --sample_batch_size 16 \\ --eval_model_path models/gpt2 """ from __future__ import annotations import argparse import copy import json import sys from pathlib import Path ROOT = Path(__file__).resolve().parents[1] # sad/ from typing import Any import numpy as np import torch import torch.nn.functional as F sys.path.insert(0, str(ROOT)) # for `src.*` sys.path.insert(0, str(Path(__file__).parent)) # for `inference_sad` from inference_sad import ( BlockDiffusionSampler, build_ancestor_table, build_model, build_tokenizer, load_config, resolve_dtype, _unwrap, ) # ───────────────────────────────────────────────────────────────────────────── # Hard-coded text input (edit this string and run without --input_text). # Takes priority over SAD sampling when non-empty. Set to "" to disable. # ───────────────────────────────────────────────────────────────────────────── INPUT_TEXT = "" def parse_args(): p = argparse.ArgumentParser() p.add_argument("--model_type", type=str, default="sad", choices=["sad", "block_diffusion"], help="Generation backend. 'sad' expects an ancestor-table " "checkpoint; 'block_diffusion' expects the mask-only checkpoint.") p.add_argument("--checkpoint", type=str, default=None, help="SAD checkpoint. Required unless --input_text or " "--input_file is given (text-only scoring mode).") p.add_argument("--config", type=str, default=None, help="Optional config path. If omitted, uses the config " "stored inside --checkpoint.") p.add_argument("--input_text", type=str, default=None, help="Score this single string under the eval LM instead " "of running SAD sampling. Skips SAD model loading.") p.add_argument("--input_file", type=str, default=None, help="Path to a text file, one sentence per line; each " "non-empty line is scored as a separate sample. " "Mutually exclusive with --input_text.") p.add_argument("--num_samples", type=int, default=256, help="Total unconditional samples to generate.") p.add_argument("--sample_batch_size", type=int, default=16, help="Batch size for SAD sampling.") p.add_argument("--eval_batch_size", type=int, default=8, help="Batch size when feeding samples to the eval LM.") p.add_argument("--eval_model_path", type=str, default="models/gpt2-large", help="Path (relative to sad/ or absolute) to a local " "HF causal-LM checkpoint used as the PPL evaluator. " "Default expects `huggingface-cli download gpt2-large " "--local-dir models/gpt2-large` to have been run.") p.add_argument("--eval_tokenizer_path", type=str, default="models/gpt2-large", help="Path to the eval-LM's tokenizer. For HF-downloaded " "models, tokenizer files sit alongside weights, so " "this defaults to the same path as --eval_model_path.") p.add_argument("--eval_max_length", type=int, default=1024, help="Truncation length for eval-LM tokenization.") p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") p.add_argument("--dtype", type=str, default="bf16", choices=["bf16", "fp16", "fp32"], help="dtype for SAD sampling (eval LM always runs fp32).") p.add_argument("--seed", type=int, default=42) p.add_argument("--output", type=str, default="outputs/gen_ppl_metrics.json") p.add_argument("--save_samples", type=str, default=None, help="Optional path to dump decoded text samples (JSON).") p.add_argument("--level_lambdas", type=str, default=None, help="Comma-separated K floats in [0, 1], one per ancestor " "level l = 1..K (e.g. '1.0,0.8,0.5'). Multiplies the " "level's max-prob conf before the cross-level argmax. " "Default: all 1.0 (original behavior).") p.add_argument("--positions_per_step", type=int, default=1, help="Number of random non-leaf positions to advance per " "denoising round within a block. Larger → fewer " "denoising rounds but less sequential refinement.") p.add_argument("--leaf_temperature", type=float, default=1.0, help="Temperature applied to leaf logits before softmax. " "Values < 1.0 sharpen p_leaf, which is then used for " "both leaf multinomial sampling and ancestor projection. " "Default 1.0 (no sharpening).") return p.parse_args() # ───────────────────────────────────────────────────────────────────────────── # Sampling # ───────────────────────────────────────────────────────────────────────────── @torch.no_grad() def sample_many(sampler: Any, num_samples: int, batch_size: int, positions_per_step: int = 1): """Generate `num_samples` unconditional sequences in chunks. Returns (tokens [N, L], avg_steps_per_sample). A generate() call shares its round count across the whole batch (the per-block loop breaks only when every sample's block is leaf), so avg is weighted by batch size. """ chunks = [] total_steps_weighted = 0 done = 0 while done < num_samples: bs = min(batch_size, num_samples - done) out = sampler.generate( batch_size=bs, positions_per_step=positions_per_step, ) chunks.append(out["tokens"]) # [bs, L] total_steps_weighted += out["num_steps"] * bs done += bs print(f" sampled {done}/{num_samples} (steps this call: {out['num_steps']})") avg_steps = total_steps_weighted / done return torch.cat(chunks, dim=0), avg_steps # [N, L], float # ───────────────────────────────────────────────────────────────────────────── # Scoring with eval LM # ───────────────────────────────────────────────────────────────────────────── @torch.no_grad() def score_with_eval_lm( texts: list, eval_model, eval_tokenizer, device: torch.device, batch_size: int, max_length: int, ) -> dict: """Standard next-token CE under a pretrained AR eval LM.""" total_nll = 0.0 total_tokens = 0 total_acc = 0.0 all_nlls = [] for i in range(0, len(texts), batch_size): batch = texts[i:i + batch_size] enc = eval_tokenizer( batch, padding=True, return_tensors="pt", truncation=True, max_length=max_length, ).to(device) input_ids = enc["input_ids"] # [B, L] attn_mask = enc["attention_mask"] # [B, L] outputs = eval_model( input_ids=input_ids, attention_mask=attn_mask, use_cache=False, return_dict=True, ) logits = outputs.logits[:, :-1] # [B, L-1, V] labels = input_ids[:, 1:] # [B, L-1] loss_mask = attn_mask[:, 1:] # [B, L-1] nll = F.cross_entropy( logits.transpose(-1, -2), labels, reduction="none", ) # [B, L-1] valid = loss_mask.bool() nll_valid = nll[valid] total_nll += nll_valid.sum().item() total_tokens += int(valid.sum().item()) all_nlls.extend(nll_valid.detach().cpu().tolist()) preds = logits.argmax(dim=-1) total_acc += ((preds == labels).float() * loss_mask).sum().item() print(f" scored {min(i + batch_size, len(texts))}/{len(texts)}") if total_tokens == 0: raise RuntimeError("No valid tokens scored — all samples were empty?") avg_nll = total_nll / total_tokens return { "avg_nll": avg_nll, "median_nll": float(np.median(all_nlls)), "ppl": float(np.exp(avg_nll)), "acc": total_acc / total_tokens, "tokens": total_tokens, } # ───────────────────────────────────────────────────────────────────────────── # main # ───────────────────────────────────────────────────────────────────────────── def main(): args = parse_args() torch.manual_seed(args.seed) device = torch.device(args.device) dtype = resolve_dtype(args.dtype) # Priority: CLI flags > file-level INPUT_TEXT constant > SAD sampling. hardcoded_text = INPUT_TEXT.strip() or None effective_input_text = args.input_text or hardcoded_text text_mode = bool(effective_input_text or args.input_file) assert not (args.input_text and args.input_file), ( "--input_text and --input_file are mutually exclusive." ) if text_mode: # ── Text-only scoring: skip SAD model loading + sampling. ─────── if effective_input_text is not None: texts = [effective_input_text] else: with open(args.input_file) as f: texts = [ln.rstrip("\n") for ln in f if ln.strip()] print(f"Scoring {len(texts)} input text(s) directly under the eval LM " f"(SAD sampling skipped).") tokens = None avg_steps = None else: # ── Load SAD model + ancestor table ───────────────────────────── assert args.checkpoint is not None, ( "--checkpoint is required unless --input_text/--input_file is set." ) ckpt = torch.load(args.checkpoint, map_location=device) if args.config is not None: config = load_config(args.config) config_source = f"cli:{args.config}" else: assert "config" in ckpt, ( "--config was not provided and checkpoint has no embedded " "'config' entry." ) config = copy.deepcopy(ckpt["config"]) config_source = f"checkpoint:{args.checkpoint}" print(f"Using config from {config_source}") if args.model_type == "sad": sad_tokenizer = build_tokenizer(config) model = build_model(config, device).to(dtype) raw_state = ckpt.get("model", ckpt) _unwrap(model).load_state_dict(raw_state, strict=False) model.eval() print(f"Loaded SAD checkpoint: {args.checkpoint} " f"(step={ckpt.get('step', '?')})") ancestor_table = build_ancestor_table( config, device, embed_dim=config["model"]["hidden_size"], ) assert "ancestor_table" in ckpt, ( "Checkpoint has no 'ancestor_table' entry." ) ancestor_table.load_state_dict(ckpt["ancestor_table"]) ancestor_table.to(device=device, dtype=dtype).eval() level_lambdas = None if args.level_lambdas: level_lambdas = [float(x) for x in args.level_lambdas.split(",")] sampler = BlockDiffusionSampler( model=_unwrap(model), ancestor_table=ancestor_table, tokenizer=sad_tokenizer, device=device, dtype=dtype, level_lambdas=level_lambdas, leaf_temperature=args.leaf_temperature, ) print(f"level_lambdas = {sampler.level_lambdas[1:]}") print(f"leaf_temperature = {sampler.leaf_temperature}") else: from inference_block_diffusion import ( BlockMaskDiffusionSampler, build_model as build_mask_model, build_tokenizer as build_mask_tokenizer, _unwrap as unwrap_mask, ) sad_tokenizer = build_mask_tokenizer(config) model = build_mask_model(config, device).to(dtype) raw_state = ckpt.get("model", ckpt) unwrap_mask(model).load_state_dict(raw_state, strict=False) model.eval() print(f"Loaded block-mask checkpoint: {args.checkpoint} " f"(step={ckpt.get('step', '?')})") sampler = BlockMaskDiffusionSampler( model=unwrap_mask(model), tokenizer=sad_tokenizer, device=device, dtype=dtype, leaf_temperature=args.leaf_temperature, ) ancestor_table = None print(f"leaf_temperature = {sampler.leaf_temperature}") # ── Generate N samples ────────────────────────────────────────── L = config["model"]["max_seq_len"] print(f"Generating {args.num_samples} samples (L={L})...") tokens, avg_steps = sample_many( sampler, args.num_samples, args.sample_batch_size, positions_per_step=args.positions_per_step, ) print(f"Average denoising rounds per sample: {avg_steps:.2f}") texts = sad_tokenizer.batch_decode( tokens.tolist(), skip_special_tokens=True, ) print(f"First sample preview: {texts[0][:120]!r}") # Free SAD-side GPU memory before loading the eval LM. del sampler, model if ancestor_table is not None: del ancestor_table torch.cuda.empty_cache() # ── Load eval LM ───────────────────────────────────────────────────── from transformers import AutoModelForCausalLM, AutoTokenizer eval_model_path = Path(args.eval_model_path) if not eval_model_path.is_absolute(): eval_model_path = ROOT / eval_model_path eval_tok_path = Path(args.eval_tokenizer_path) if not eval_tok_path.is_absolute(): eval_tok_path = ROOT / eval_tok_path print(f"Loading eval LM: {eval_model_path}") print(f"Loading eval tokenizer: {eval_tok_path}") eval_tokenizer = AutoTokenizer.from_pretrained( str(eval_tok_path), local_files_only=True, ) if eval_tokenizer.pad_token is None: eval_tokenizer.pad_token = eval_tokenizer.eos_token eval_model = AutoModelForCausalLM.from_pretrained( str(eval_model_path), local_files_only=True, torch_dtype=torch.float32, # match HDLM's stability choice ).to(device).eval() print(f"Eval LM loaded ({sum(p.numel() for p in eval_model.parameters()):,} params)") # ── Score ──────────────────────────────────────────────────────────── print("Scoring samples under eval LM...") metrics = score_with_eval_lm( texts, eval_model, eval_tokenizer, device, args.eval_batch_size, args.eval_max_length, ) metrics.update({ "checkpoint": args.checkpoint, "eval_model": str(eval_model_path), "eval_tokenizer": str(eval_tok_path), "num_samples": len(texts), "generated_seq_len": int(tokens.shape[1]) if tokens is not None else None, "mode": "text_input" if text_mode else ( "block_diffusion_generation" if args.model_type == "block_diffusion" else "sad_generation" ), "model_type": args.model_type, "level_lambdas": None if args.model_type == "block_diffusion" else args.level_lambdas, "avg_steps": avg_steps, "positions_per_step": args.positions_per_step, "leaf_temperature": args.leaf_temperature, }) print(json.dumps(metrics, indent=2)) out_path = Path(args.output) if not out_path.is_absolute(): out_path = ROOT / out_path out_path.parent.mkdir(parents=True, exist_ok=True) with open(out_path, "w") as f: json.dump(metrics, f, indent=2) print(f"Saved metrics → {out_path}") if args.save_samples: s_path = Path(args.save_samples) if not s_path.is_absolute(): s_path = ROOT / s_path s_path.parent.mkdir(parents=True, exist_ok=True) with open(s_path, "w") as f: json.dump({"samples": texts}, f, indent=2) print(f"Saved samples → {s_path}") if __name__ == "__main__": main()