sad / scripts /eval_gen_ppl.py
haochengsama's picture
Add files using upload-large-folder tool
8b0aeb2 verified
Raw
History Blame Contribute Delete
18.8 kB
#!/usr/bin/env python3
"""
eval_gen_ppl.py – Generative perplexity of SAD samples under a pretrained LM.
Mirrors the standard "gen_ppl" pipeline used by HDLM/MDLM/soft-mask:
1. Draw N unconditional samples from a trained SAD checkpoint
(length = model.max_seq_len).
2. Decode them into text with the SAD tokenizer.
3. Feed the decoded text through a pretrained AR eval LM (default: local
gpt2-large), compute standard next-token cross-entropy.
4. Report avg_nll, median_nll, ppl = exp(total_nll / total_tokens), acc.
The metric measures how "natural" SAD samples look under the eval LM — it is
NOT a model-intrinsic PPL (no ELBO, no test set). It is directly comparable
to soft-mask's `val/gen_ppl` and HDLM's `eval/generative_ppl.py`.
Usage:
python scripts/eval_gen_ppl.py \\
--checkpoint outputs/sad/latest.pt \\
--config configs/sad_owt.yaml \\
--num_samples 256 \\
--sample_batch_size 16 \\
--eval_model_path models/gpt2
"""
from __future__ import annotations
import argparse
import copy
import json
import sys
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1] # sad/
from typing import Any
import numpy as np
import torch
import torch.nn.functional as F
sys.path.insert(0, str(ROOT)) # for `src.*`
sys.path.insert(0, str(Path(__file__).parent)) # for `inference_sad`
from inference_sad import (
BlockDiffusionSampler,
build_ancestor_table,
build_model,
build_tokenizer,
load_config,
resolve_dtype,
_unwrap,
)
# -----------------------------------------------------------------------------
# Optional hard-coded text to score instead of sampling.
# Consulted only when --input_text is not given. A non-empty value (after
# stripping surrounding whitespace) switches the script into text-scoring
# mode, which skips SAD model loading and sampling. Leave as "" to disable.
# -----------------------------------------------------------------------------
INPUT_TEXT: str = ""
def parse_args(argv=None):
    """Parse command-line options for generative-perplexity evaluation.

    Args:
        argv: Arguments to parse. ``None`` (the default) reads
            ``sys.argv[1:]``, so existing ``parse_args()`` calls are unchanged.

    Returns:
        argparse.Namespace with one attribute per option below.
        ``eval_tokenizer_path`` falls back to ``eval_model_path`` when the
        flag is omitted, so pointing ``--eval_model_path`` at another model
        also switches the tokenizer (as the help text promises).
    """
    p = argparse.ArgumentParser()
    p.add_argument("--model_type", type=str, default="sad",
                   choices=["sad", "block_diffusion"],
                   help="Generation backend. 'sad' expects an ancestor-table "
                        "checkpoint; 'block_diffusion' expects the mask-only checkpoint.")
    p.add_argument("--checkpoint", type=str, default=None,
                   help="SAD checkpoint. Required unless --input_text or "
                        "--input_file is given (text-only scoring mode).")
    p.add_argument("--config", type=str, default=None,
                   help="Optional config path. If omitted, uses the config "
                        "stored inside --checkpoint.")
    # Text-only scoring sources. argparse rejects passing both at parse time
    # instead of relying on an `assert` later (which -O would strip).
    text_source = p.add_mutually_exclusive_group()
    text_source.add_argument("--input_text", type=str, default=None,
                             help="Score this single string under the eval LM instead "
                                  "of running SAD sampling. Skips SAD model loading.")
    text_source.add_argument("--input_file", type=str, default=None,
                             help="Path to a text file, one sentence per line; each "
                                  "non-empty line is scored as a separate sample. "
                                  "Mutually exclusive with --input_text.")
    p.add_argument("--num_samples", type=int, default=256,
                   help="Total unconditional samples to generate.")
    p.add_argument("--sample_batch_size", type=int, default=16,
                   help="Batch size for SAD sampling.")
    p.add_argument("--eval_batch_size", type=int, default=8,
                   help="Batch size when feeding samples to the eval LM.")
    p.add_argument("--eval_model_path", type=str, default="models/gpt2-large",
                   help="Path (relative to sad/ or absolute) to a local "
                        "HF causal-LM checkpoint used as the PPL evaluator. "
                        "Default expects `huggingface-cli download gpt2-large "
                        "--local-dir models/gpt2-large` to have been run.")
    p.add_argument("--eval_tokenizer_path", type=str, default=None,
                   help="Path to the eval-LM's tokenizer. Defaults to "
                        "--eval_model_path, since HF-downloaded models keep "
                        "their tokenizer files alongside the weights.")
    p.add_argument("--eval_max_length", type=int, default=1024,
                   help="Truncation length for eval-LM tokenization.")
    p.add_argument("--device", type=str,
                   default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument("--dtype", type=str, default="bf16",
                   choices=["bf16", "fp16", "fp32"],
                   help="dtype for SAD sampling (eval LM always runs fp32).")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--output", type=str, default="outputs/gen_ppl_metrics.json")
    p.add_argument("--save_samples", type=str, default=None,
                   help="Optional path to dump decoded text samples (JSON).")
    p.add_argument("--level_lambdas", type=str, default=None,
                   help="Comma-separated K floats in [0, 1], one per ancestor "
                        "level l = 1..K (e.g. '1.0,0.8,0.5'). Multiplies the "
                        "level's max-prob conf before the cross-level argmax. "
                        "Default: all 1.0 (original behavior).")
    p.add_argument("--positions_per_step", type=int, default=1,
                   help="Number of random non-leaf positions to advance per "
                        "denoising round within a block. Larger -> fewer "
                        "denoising rounds but less sequential refinement.")
    p.add_argument("--leaf_temperature", type=float, default=1.0,
                   help="Temperature applied to leaf logits before softmax. "
                        "Values < 1.0 sharpen p_leaf, which is then used for "
                        "both leaf multinomial sampling and ancestor projection. "
                        "Default 1.0 (no sharpening).")
    args = p.parse_args(argv)
    if args.eval_tokenizer_path is None:
        # Keep tokenizer and model in sync unless explicitly overridden.
        args.eval_tokenizer_path = args.eval_model_path
    return args
# ─────────────────────────────────────────────────────────────────────────────
# Sampling
# ─────────────────────────────────────────────────────────────────────────────
@torch.no_grad()
def sample_many(sampler: Any,
                num_samples: int, batch_size: int,
                positions_per_step: int = 1):
    """Generate `num_samples` unconditional sequences in chunks of `batch_size`.

    Args:
        sampler: Object exposing ``generate(batch_size=..., positions_per_step=...)``
            that returns a dict with ``"tokens"`` (a ``[bs, L]`` tensor) and
            ``"num_steps"`` (denoising rounds used by that call).
        num_samples: Total number of sequences to draw; must be positive.
        batch_size: Maximum sequences per ``generate`` call; must be positive.
        positions_per_step: Forwarded unchanged to ``sampler.generate``.

    Returns:
        (tokens [N, L], avg_steps_per_sample). A generate() call shares its
        round count across the whole batch (the per-block loop breaks only
        when every sample's block is leaf), so the average is weighted by the
        size of each batch.

    Raises:
        ValueError: if `num_samples` or `batch_size` is not positive (the
            former would divide by zero, the latter would never terminate).
    """
    if num_samples <= 0:
        raise ValueError(f"num_samples must be positive, got {num_samples}")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    chunks = []
    total_steps_weighted = 0
    done = 0
    while done < num_samples:
        bs = min(batch_size, num_samples - done)
        out = sampler.generate(
            batch_size=bs, positions_per_step=positions_per_step,
        )
        chunks.append(out["tokens"])  # [bs, L]
        total_steps_weighted += out["num_steps"] * bs
        done += bs
        print(f" sampled {done}/{num_samples} (steps this call: {out['num_steps']})")
    avg_steps = total_steps_weighted / done
    return torch.cat(chunks, dim=0), avg_steps  # [N, L], float
# ─────────────────────────────────────────────────────────────────────────────
# Scoring with eval LM
# ─────────────────────────────────────────────────────────────────────────────
@torch.no_grad()
def score_with_eval_lm(
    texts: list,
    eval_model,
    eval_tokenizer,
    device: torch.device,
    batch_size: int,
    max_length: int,
) -> dict:
    """Score `texts` with standard next-token cross-entropy under an AR eval LM.

    Texts are tokenized in batches (padded and truncated to `max_length`).
    Position t's logits are scored against token t+1. A target is counted
    only when BOTH the context position t and the target position t+1 are
    real (attention_mask == 1) tokens. Under right padding this is the same
    as masking by the target alone. Under left padding it also stops the
    first real token from being "predicted" from a pad position.

    Args:
        texts: Strings to score.
        eval_model: Causal LM called as ``eval_model(input_ids=...,
            attention_mask=..., use_cache=False, return_dict=True)`` and
            returning an object with ``.logits`` of shape ``[B, L, V]``.
        eval_tokenizer: Callable tokenizer returning ``input_ids`` and
            ``attention_mask`` tensors (HF-style, with ``.to(device)``).
        device: Device the encoded batch is moved to.
        batch_size: Texts per forward pass.
        max_length: Truncation length passed to the tokenizer.

    Returns:
        dict with ``avg_nll`` (mean per-token NLL), ``median_nll``,
        ``ppl`` = exp(avg_nll), ``acc`` (greedy next-token accuracy over the
        scored positions) and ``tokens`` (number of scored positions).

    Raises:
        RuntimeError: if no position could be scored (e.g. `texts` is empty
            or every text tokenizes to at most one token).
    """
    total_nll = 0.0
    total_tokens = 0
    total_correct = 0
    all_nlls = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        enc = eval_tokenizer(
            batch, padding=True, return_tensors="pt",
            truncation=True, max_length=max_length,
        ).to(device)
        input_ids = enc["input_ids"]        # [B, L]
        attn_mask = enc["attention_mask"]   # [B, L]
        # With fewer than two positions there is no (context, target) pair,
        # so the forward pass would contribute nothing.
        if input_ids.shape[1] >= 2:
            outputs = eval_model(
                input_ids=input_ids, attention_mask=attn_mask,
                use_cache=False, return_dict=True,
            )
            logits = outputs.logits[:, :-1]  # [B, L-1, V]
            labels = input_ids[:, 1:]        # [B, L-1]
            # Require a real target AND a real context position.
            valid = attn_mask[:, 1:].bool() & attn_mask[:, :-1].bool()
            nll = F.cross_entropy(
                logits.transpose(-1, -2), labels, reduction="none",
            )  # [B, L-1]
            nll_valid = nll[valid]
            total_nll += nll_valid.sum().item()
            total_tokens += int(valid.sum().item())
            all_nlls.extend(nll_valid.detach().cpu().tolist())
            preds = logits.argmax(dim=-1)
            total_correct += int((preds == labels)[valid].sum().item())
        print(f" scored {min(start + batch_size, len(texts))}/{len(texts)}")
    if total_tokens == 0:
        raise RuntimeError("No valid tokens scored -- all samples were empty?")
    avg_nll = total_nll / total_tokens
    return {
        "avg_nll": avg_nll,
        "median_nll": float(np.median(all_nlls)),
        "ppl": float(np.exp(avg_nll)),
        "acc": total_correct / total_tokens,
        "tokens": total_tokens,
    }
# ─────────────────────────────────────────────────────────────────────────────
# main
# ─────────────────────────────────────────────────────────────────────────────
def _resolve_path(path_str: str) -> Path:
    """Return `path_str` as a Path; relative paths are anchored at ROOT (sad/)."""
    path = Path(path_str)
    return path if path.is_absolute() else ROOT / path


def _write_json(path: Path, payload: dict) -> None:
    """Write `payload` as indented JSON to `path`, creating parent dirs."""
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)


def _collect_input_texts(args):
    """Return the texts to score directly, or None if SAD sampling should run.

    Priority (highest first): --input_text, --input_file, then the
    file-level INPUT_TEXT constant, so CLI flags always override INPUT_TEXT.

    Raises:
        SystemExit: if both --input_text and --input_file are given.
    """
    if args.input_text and args.input_file:
        raise SystemExit("--input_text and --input_file are mutually exclusive.")
    if args.input_text:
        return [args.input_text]
    if args.input_file:
        # One sample per non-empty line; only the trailing newline is removed.
        with open(args.input_file, encoding="utf-8") as f:
            return [ln.rstrip("\n") for ln in f if ln.strip()]
    hardcoded_text = INPUT_TEXT.strip()
    return [hardcoded_text] if hardcoded_text else None


def _load_weights(module, state_dict, label):
    """Non-strictly load `state_dict` into `module`, warning on key mismatches.

    strict=False tolerates missing/unexpected keys, which would otherwise
    hide a wholesale mismatch (e.g. differing key prefixes) and leave the
    model with untrained weights.
    """
    result = module.load_state_dict(state_dict, strict=False)
    missing = getattr(result, "missing_keys", [])
    unexpected = getattr(result, "unexpected_keys", [])
    if missing or unexpected:
        print(f"WARNING: {label}: {len(missing)} missing / {len(unexpected)} "
              f"unexpected state_dict keys (loaded with strict=False).")


def _build_sampler(args, config, ckpt, device, dtype):
    """Build the generation sampler for --model_type from a loaded checkpoint.

    Returns:
        (sampler, tokenizer). Model weights (and, for 'sad', the ancestor
        table) are copied out of `ckpt` by load_state_dict, so the caller may
        drop `ckpt` afterwards.

    Raises:
        ValueError: if a 'sad' checkpoint has no 'ancestor_table' entry.
    """
    raw_state = ckpt.get("model", ckpt)
    if args.model_type == "sad":
        # Check before building anything so a bad checkpoint fails fast.
        if "ancestor_table" not in ckpt:
            raise ValueError("Checkpoint has no 'ancestor_table' entry.")
        tokenizer = build_tokenizer(config)
        model = build_model(config, device).to(dtype)
        _load_weights(_unwrap(model), raw_state, "SAD model")
        model.eval()
        print(f"Loaded SAD checkpoint: {args.checkpoint} "
              f"(step={ckpt.get('step', '?')})")
        ancestor_table = build_ancestor_table(
            config, device, embed_dim=config["model"]["hidden_size"],
        )
        ancestor_table.load_state_dict(ckpt["ancestor_table"])
        ancestor_table.to(device=device, dtype=dtype).eval()
        level_lambdas = None
        if args.level_lambdas:
            level_lambdas = [float(x) for x in args.level_lambdas.split(",")]
        sampler = BlockDiffusionSampler(
            model=_unwrap(model), ancestor_table=ancestor_table,
            tokenizer=tokenizer, device=device, dtype=dtype,
            level_lambdas=level_lambdas,
            leaf_temperature=args.leaf_temperature,
        )
        print(f"level_lambdas = {sampler.level_lambdas[1:]}")
        print(f"leaf_temperature = {sampler.leaf_temperature}")
        return sampler, tokenizer

    from inference_block_diffusion import (
        BlockMaskDiffusionSampler,
        build_model as build_mask_model,
        build_tokenizer as build_mask_tokenizer,
        _unwrap as unwrap_mask,
    )
    tokenizer = build_mask_tokenizer(config)
    model = build_mask_model(config, device).to(dtype)
    _load_weights(unwrap_mask(model), raw_state, "block-mask model")
    model.eval()
    print(f"Loaded block-mask checkpoint: {args.checkpoint} "
          f"(step={ckpt.get('step', '?')})")
    sampler = BlockMaskDiffusionSampler(
        model=unwrap_mask(model),
        tokenizer=tokenizer,
        device=device,
        dtype=dtype,
        leaf_temperature=args.leaf_temperature,
    )
    print(f"leaf_temperature = {sampler.leaf_temperature}")
    return sampler, tokenizer


def _generate_texts(args, device, dtype):
    """Load the generator checkpoint, sample, decode, then release it.

    Returns:
        (texts, tokens [N, L], avg_steps).
    """
    if args.checkpoint is None:
        raise SystemExit(
            "--checkpoint is required unless --input_text/--input_file is set."
        )
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(args.checkpoint, map_location=device)
    if args.config is not None:
        config = load_config(args.config)
        config_source = f"cli:{args.config}"
    elif "config" in ckpt:
        config = copy.deepcopy(ckpt["config"])
        config_source = f"checkpoint:{args.checkpoint}"
    else:
        raise SystemExit(
            "--config was not provided and checkpoint has no embedded "
            "'config' entry."
        )
    print(f"Using config from {config_source}")

    sampler, tokenizer = _build_sampler(args, config, ckpt, device, dtype)
    # The checkpoint was loaded onto `device` and its weights have been copied
    # into the model; drop it now so its tensors do not stay resident while
    # the eval LM is loaded and run.
    del ckpt

    seq_len = config["model"]["max_seq_len"]
    print(f"Generating {args.num_samples} samples (L={seq_len})...")
    tokens, avg_steps = sample_many(
        sampler, args.num_samples, args.sample_batch_size,
        positions_per_step=args.positions_per_step,
    )
    print(f"Average denoising rounds per sample: {avg_steps:.2f}")
    texts = tokenizer.batch_decode(tokens.tolist(), skip_special_tokens=True)
    print(f"First sample preview: {texts[0][:120]!r}")

    # Free generator-side GPU memory before the eval LM is loaded.
    del sampler
    torch.cuda.empty_cache()
    return texts, tokens, avg_steps


def _load_eval_lm(args, device):
    """Load the local HF causal LM and tokenizer used as the PPL evaluator.

    Returns:
        (eval_model, eval_tokenizer, eval_model_path, eval_tok_path).
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    eval_model_path = _resolve_path(args.eval_model_path)
    # Fall back to the model directory when no tokenizer path is set.
    eval_tok_path = _resolve_path(args.eval_tokenizer_path or args.eval_model_path)
    print(f"Loading eval LM: {eval_model_path}")
    print(f"Loading eval tokenizer: {eval_tok_path}")
    eval_tokenizer = AutoTokenizer.from_pretrained(
        str(eval_tok_path), local_files_only=True,
    )
    if eval_tokenizer.pad_token is None:
        # Reuse EOS for padding; padded positions are excluded via the
        # attention mask when scoring.
        eval_tokenizer.pad_token = eval_tokenizer.eos_token
    eval_model = AutoModelForCausalLM.from_pretrained(
        str(eval_model_path), local_files_only=True,
        torch_dtype=torch.float32,  # match HDLM's stability choice
    ).to(device).eval()
    n_params = sum(p.numel() for p in eval_model.parameters())
    print(f"Eval LM loaded ({n_params:,} params)")
    return eval_model, eval_tokenizer, eval_model_path, eval_tok_path


def main():
    """CLI entry point: gather texts (given or generated), score them, save metrics."""
    args = parse_args()
    torch.manual_seed(args.seed)
    device = torch.device(args.device)
    dtype = resolve_dtype(args.dtype)

    texts = _collect_input_texts(args)
    text_mode = texts is not None
    if text_mode:
        # Text-only scoring: skip SAD model loading + sampling.
        print(f"Scoring {len(texts)} input text(s) directly under the eval LM "
              f"(SAD sampling skipped).")
        tokens, avg_steps = None, None
    else:
        texts, tokens, avg_steps = _generate_texts(args, device, dtype)

    eval_model, eval_tokenizer, eval_model_path, eval_tok_path = _load_eval_lm(
        args, device,
    )

    print("Scoring samples under eval LM...")
    metrics = score_with_eval_lm(
        texts, eval_model, eval_tokenizer, device,
        args.eval_batch_size, args.eval_max_length,
    )
    if text_mode:
        mode = "text_input"
    elif args.model_type == "block_diffusion":
        mode = "block_diffusion_generation"
    else:
        mode = "sad_generation"
    metrics.update({
        "checkpoint": args.checkpoint,
        "eval_model": str(eval_model_path),
        "eval_tokenizer": str(eval_tok_path),
        "num_samples": len(texts),
        "generated_seq_len": int(tokens.shape[1]) if tokens is not None else None,
        "mode": mode,
        "model_type": args.model_type,
        "level_lambdas": None if args.model_type == "block_diffusion" else args.level_lambdas,
        "avg_steps": avg_steps,
        "positions_per_step": args.positions_per_step,
        "leaf_temperature": args.leaf_temperature,
    })
    print(json.dumps(metrics, indent=2))

    out_path = _resolve_path(args.output)
    _write_json(out_path, metrics)
    print(f"Saved metrics -> {out_path}")
    if args.save_samples:
        s_path = _resolve_path(args.save_samples)
        _write_json(s_path, {"samples": texts})
        print(f"Saved samples -> {s_path}")
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()