sad / scripts /compute_mauve.py
haochengsama's picture
Add files using upload-large-folder tool
8b0aeb2 verified
Raw
History Blame Contribute Delete
23.8 kB
#!/usr/bin/env python3
"""
compute_mauve.py – Generate samples from a SAD, mask-only block-diffusion, or AR
checkpoint and compute MAUVE against OpenWebText reference texts.
Pipeline:
1. Load checkpoint (SAD block diffusion, mask-only block diffusion, or autoregressive AR).
2. Generate N unconditional samples (default 5000).
3. Load the OWT reference pool (the last 100k docs of the train split, used as a
held-out set) and randomly pick N reference texts.
4. Decode both sides to plain text.
5. Compute MAUVE score using a local GPT-2 Large as the featurize model.
6. Save results to JSON.
Usage (SAD block diffusion):
python scripts/compute_mauve.py \
--checkpoint outputs/sad/latest.pt \
--model_type sad \
--num_samples 5000 \
--sample_batch_size 64 \
--positions_per_step 1 \
--level_lambdas 1.0 \
--leaf_temperature 1.0 \
--run_id 1
Usage (AR baseline):
python scripts/compute_mauve.py \
--checkpoint outputs/ar_baseline/latest.pt \
--model_type ar \
--num_samples 5000 \
--sample_batch_size 64 \
--temperature 1.0 \
--top_k 0 \
--top_p 1.0 \
--run_id 1
"""
from __future__ import annotations
import argparse
import copy
import json
import random
import sys
import time
from pathlib import Path
import numpy as np
import torch
ROOT = Path(__file__).resolve().parents[1] # sad/
sys.path.insert(0, str(ROOT)) # for `src.*`
sys.path.insert(0, str(ROOT / "scripts")) # for inference_sad / inference_ar
from src.data import build_owt_dataloader
def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Options are grouped by concern: model family, checkpoint/config,
    SAD-only sampling knobs, AR-only sampling knobs, sizes shared by all
    model families, the OWT reference pool, MAUVE featurisation,
    device/precision/seed, and output locations.

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser(
        description="Generate samples (block or AR) and compute MAUVE against OWT.",
    )
    opt = parser.add_argument

    # Model family
    opt("--model_type", type=str, default="sad",
        choices=["sad", "block_diffusion", "ar"],
        help=("Model type: 'sad' for SAD block-diffusion, "
              "'block_diffusion' for mask-only block diffusion, "
              "'ar' for autoregressive baseline."))

    # Checkpoint and config
    opt("--checkpoint", type=str, required=True,
        help="Path to checkpoint (.pt).")
    opt("--config", type=str, default=None,
        help=("Optional config path. If omitted, uses config "
              "stored in the checkpoint."))

    # SAD-only sampling controls
    opt("--positions_per_step", type=int, default=1,
        help="(sad) Non-leaf positions advanced per denoising round.")
    opt("--level_lambdas", type=str, default=None,
        help=("(sad) Comma-separated floats in [0,1], one per ancestor "
              "level (e.g. '1.0,0.8'). Default: all 1.0."))
    opt("--leaf_temperature", type=float, default=1.0,
        help="(sad) Temperature on leaf logits before softmax.")

    # AR-only sampling controls
    opt("--temperature", type=float, default=1.0,
        help="(ar) Sampling temperature. 0 means greedy.")
    opt("--top_k", type=int, default=0,
        help="(ar) Top-k sampling. 0 disables.")
    opt("--top_p", type=float, default=1.0,
        help="(ar) Top-p (nucleus) sampling. 1.0 disables.")
    opt("--max_new_tokens", type=int, default=None,
        help=("(ar) Number of tokens to generate after BOS. "
              "Default: max_seq_len - 1 from config."))
    opt("--no_stop_on_eos", action="store_true",
        help="(ar) Keep sampling until max_new_tokens is reached.")

    # Sizes shared by every model family
    opt("--num_samples", type=int, default=5000,
        help=("Number of samples to generate AND reference texts "
              "to draw from OWT. Default 5000."))
    opt("--sample_batch_size", type=int, default=64,
        help="Batch size for sampling.")

    # Reference pool
    opt("--ref_pool_size", type=int, default=100_000,
        help=("Size of the OWT test pool to draw from. "
              "Default 100000 (last 100k docs)."))

    # MAUVE featurisation
    opt("--featurize_model", type=str, default="models/gpt2-large",
        help=("Model name or local path for MAUVE featurization. "
              "Passed to AutoModel.from_pretrained(). "
              "Default: models/gpt2-large (local)."))
    opt("--mauve_max_text_length", type=int, default=512,
        help="Max tokens for MAUVE featurization. Default 512.")
    opt("--mauve_batch_size", type=int, default=8,
        help="Batch size for MAUVE featurization. Default 8.")

    # Hardware, numerics and reproducibility
    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    opt("--device", type=str, default=default_device)
    opt("--dtype", type=str, default="bf16",
        choices=["bf16", "fp16", "fp32"])
    opt("--seed", type=int, default=42)

    # Output locations
    opt("--output_dir", type=str, default="eval/mauve_results",
        help=("Directory to save MAUVE result JSONs when --output "
              "is not provided."))
    opt("--output", type=str, default=None,
        help=("Optional full path to the result JSON. If set, "
              "overrides --output_dir and the default filename."))
    opt("--save_samples", type=str, default=None,
        help=("Optional JSON path to save generated and reference "
              "texts for later inspection or MAUVE recomputation."))
    opt("--run_id", type=int, default=1,
        help=("Run identifier for distinguishing repeated runs. "
              "Output file: mauve_{model_type}_run{run_id}.json"))

    return parser.parse_args()
def resolve_path(raw: str) -> Path:
    """Anchor a user-supplied path at the ``sad/`` repository root.

    Absolute inputs are returned unchanged (as ``Path``); relative inputs are
    joined onto ``ROOT`` rather than the current working directory.
    """
    candidate = Path(raw)
    return candidate if candidate.is_absolute() else ROOT / candidate
# ─────────────────────────────────────────────────────────────────────────────
# Block diffusion sampling
# ─────────────────────────────────────────────────────────────────────────────
def load_sad_model(args, config, device, dtype):
    """Build a SAD ``BlockDiffusionSampler`` from ``args.checkpoint``.

    Model weights are read from ``ckpt["model"]``; if that key is absent, the
    whole checkpoint is treated as a bare state dict. The ancestor table is
    read from ``ckpt["ancestor_table"]``, which must be present.

    Args:
        args: Parsed CLI namespace; uses ``checkpoint``, ``level_lambdas``
            (comma-separated floats) and ``leaf_temperature``.
        config: Config dict; ``config["model"]["hidden_size"]`` sizes the
            ancestor table.
        device: Torch device for the checkpoint tensors and the model.
        dtype: Torch dtype the model and the ancestor table are cast to.

    Returns:
        ``(sampler, tokenizer, config)``; ``config`` is the argument that was
        passed in.

    Raises:
        ValueError: if the checkpoint has no ``"ancestor_table"`` entry.
    """
    from inference_sad import (
        BlockDiffusionSampler,
        build_ancestor_table,
        build_model,
        build_tokenizer,
        _unwrap,
    )

    ckpt_path = resolve_path(args.checkpoint)
    ckpt = torch.load(ckpt_path, map_location=device)
    # Explicit check rather than ``assert`` (stripped under ``python -O``),
    # done before building the model so a wrong checkpoint fails fast.
    if "ancestor_table" not in ckpt:
        raise ValueError(
            f"SAD checkpoint {ckpt_path} has no 'ancestor_table' entry."
        )

    tokenizer = build_tokenizer(config)
    model = build_model(config, device).to(dtype)
    raw_state = ckpt.get("model", ckpt)
    # strict=False tolerates key mismatches; surface them so a silently
    # partially-initialised model cannot produce a misleading MAUVE score.
    incompatible = _unwrap(model).load_state_dict(raw_state, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(
            f"WARNING: state-dict mismatch for {ckpt_path}: "
            f"{len(incompatible.missing_keys)} missing "
            f"(e.g. {incompatible.missing_keys[:5]}), "
            f"{len(incompatible.unexpected_keys)} unexpected "
            f"(e.g. {incompatible.unexpected_keys[:5]})"
        )
    model.eval()
    print(f"Loaded SAD checkpoint: {ckpt_path} (step={ckpt.get('step', '?')})")

    ancestor_table = build_ancestor_table(
        config, device, embed_dim=config["model"]["hidden_size"],
    )
    ancestor_table.load_state_dict(ckpt["ancestor_table"])
    ancestor_table.to(device=device, dtype=dtype).eval()

    level_lambdas = None
    if args.level_lambdas:
        level_lambdas = [float(x) for x in args.level_lambdas.split(",")]
    sampler = BlockDiffusionSampler(
        model=_unwrap(model), ancestor_table=ancestor_table,
        tokenizer=tokenizer, device=device, dtype=dtype,
        level_lambdas=level_lambdas,
        leaf_temperature=args.leaf_temperature,
    )
    print(f"level_lambdas = {sampler.level_lambdas[1:]}")
    print(f"leaf_temperature = {sampler.leaf_temperature}")
    return sampler, tokenizer, config
def load_block_diffusion_model(args, config, device, dtype):
    """Build a mask-only ``BlockMaskDiffusionSampler`` from ``args.checkpoint``.

    Model weights are read from ``ckpt["model"]``; if that key is absent, the
    whole checkpoint is treated as a bare state dict.

    Args:
        args: Parsed CLI namespace; uses ``checkpoint`` and
            ``leaf_temperature``.
        config: Config dict forwarded to the tokenizer/model builders.
        device: Torch device for the checkpoint tensors and the model.
        dtype: Torch dtype the model is cast to.

    Returns:
        ``(sampler, tokenizer, config)``; ``config`` is the argument that was
        passed in.
    """
    from inference_block_diffusion import (
        BlockMaskDiffusionSampler,
        build_model,
        build_tokenizer,
        _unwrap,
    )

    ckpt_path = resolve_path(args.checkpoint)
    ckpt = torch.load(ckpt_path, map_location=device)
    tokenizer = build_tokenizer(config)
    model = build_model(config, device).to(dtype)
    raw_state = ckpt.get("model", ckpt)
    # strict=False tolerates key mismatches; surface them so a silently
    # partially-initialised model cannot produce a misleading MAUVE score.
    incompatible = _unwrap(model).load_state_dict(raw_state, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(
            f"WARNING: state-dict mismatch for {ckpt_path}: "
            f"{len(incompatible.missing_keys)} missing "
            f"(e.g. {incompatible.missing_keys[:5]}), "
            f"{len(incompatible.unexpected_keys)} unexpected "
            f"(e.g. {incompatible.unexpected_keys[:5]})"
        )
    model.eval()
    print(f"Loaded block-diffusion checkpoint: {ckpt_path} (step={ckpt.get('step', '?')})")

    sampler = BlockMaskDiffusionSampler(
        model=_unwrap(model),
        tokenizer=tokenizer,
        device=device,
        dtype=dtype,
        leaf_temperature=args.leaf_temperature,
    )
    print(f"leaf_temperature = {sampler.leaf_temperature}")
    return sampler, tokenizer, config
@torch.no_grad()
def sample_block(sampler, num_samples, batch_size, positions_per_step=1):
    """Draw ``num_samples`` sequences from a block-diffusion sampler in batches.

    Args:
        sampler: Object exposing ``generate(batch_size=..., positions_per_step=...)``
            that returns a dict with ``"tokens"`` (batch along dim 0) and
            ``"num_steps"``.
        num_samples: Total number of sequences to draw; must be >= 1.
        batch_size: Upper bound on sequences per ``generate`` call; must be
            >= 1. The last batch is shrunk so exactly ``num_samples`` are drawn.
        positions_per_step: Forwarded unchanged to ``sampler.generate``.

    Returns:
        ``(tokens, stats)``: ``tokens`` is every batch concatenated along
        dim 0; ``stats`` holds per-batch step counts and wall-clock timing.

    Raises:
        ValueError: if ``num_samples`` or ``batch_size`` is < 1. With a
            non-positive batch size the loop would never terminate.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    chunks = []
    steps_list = []
    done = 0
    t0 = time.perf_counter()
    while done < num_samples:
        bs = min(batch_size, num_samples - done)
        out = sampler.generate(
            batch_size=bs, positions_per_step=positions_per_step,
        )
        chunks.append(out["tokens"])
        steps_list.append(out["num_steps"])
        done += bs
        print(f" sampled {done}/{num_samples} "
              f"(steps this batch: {out['num_steps']})")
    # perf_counter is monotonic; the clamp keeps samples_per_sec finite even
    # when a run completes within the timer's resolution.
    elapsed = max(time.perf_counter() - t0, 1e-9)
    stats = {
        "total_batches": len(steps_list),
        "steps_per_batch": steps_list,
        "avg_steps": sum(steps_list) / len(steps_list),
        "min_steps": min(steps_list),
        "max_steps": max(steps_list),
        "sampling_time_sec": round(elapsed, 2),
        "samples_per_sec": round(num_samples / elapsed, 2),
    }
    return torch.cat(chunks, dim=0), stats
# ─────────────────────────────────────────────────────────────────────────────
# AR sampling
# ─────────────────────────────────────────────────────────────────────────────
def load_ar_model(args, config, device, dtype):
    """Build an autoregressive ``ARSampler`` from ``args.checkpoint``.

    Model weights are read from ``ckpt["model"]``; if that key is absent, the
    whole checkpoint is treated as a bare state dict.

    Args:
        args: Parsed CLI namespace; uses ``checkpoint``.
        config: Config dict forwarded to the tokenizer/model builders.
        device: Torch device for the checkpoint tensors and the model.
        dtype: Torch dtype the model is cast to.

    Returns:
        ``(sampler, tokenizer, config)``; ``config`` is the argument that was
        passed in.
    """
    from inference_ar import (
        ARSampler,
        build_model,
        build_tokenizer,
        _unwrap,
    )

    ckpt_path = resolve_path(args.checkpoint)
    ckpt = torch.load(ckpt_path, map_location=device)
    tokenizer = build_tokenizer(config)
    model = build_model(config, device).to(dtype)
    raw_state = ckpt.get("model", ckpt)
    # strict=False tolerates key mismatches; surface them so a silently
    # partially-initialised model cannot produce a misleading MAUVE score.
    incompatible = _unwrap(model).load_state_dict(raw_state, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(
            f"WARNING: state-dict mismatch for {ckpt_path}: "
            f"{len(incompatible.missing_keys)} missing "
            f"(e.g. {incompatible.missing_keys[:5]}), "
            f"{len(incompatible.unexpected_keys)} unexpected "
            f"(e.g. {incompatible.unexpected_keys[:5]})"
        )
    model.eval()
    print(f"Loaded AR checkpoint: {ckpt_path} (step={ckpt.get('step', '?')})")

    sampler = ARSampler(
        model=_unwrap(model), tokenizer=tokenizer,
        device=device, dtype=dtype,
    )
    return sampler, tokenizer, config
@torch.no_grad()
def sample_ar(sampler, num_samples, batch_size, max_new_tokens,
              temperature, top_k, top_p, stop_on_eos):
    """Draw ``num_samples`` unconditional AR samples, each prompted with BOS.

    Args:
        sampler: Object with ``tokenizer`` (needs ``bos_token_id`` and
            ``eos_token_id``; ``pad_token_id`` is optional), ``device``, and
            ``generate(prompt_ids=..., max_new_tokens=..., temperature=...,
            top_k=..., top_p=..., eos_token_id=..., stop_on_eos=...)``
            returning a 2-D token tensor for the batch.
        num_samples: Total number of sequences to draw; must be >= 1.
        batch_size: Upper bound on sequences per ``generate`` call; must be >= 1.
        max_new_tokens, temperature, top_k, top_p, stop_on_eos: Forwarded
            unchanged to ``sampler.generate``.

    Returns:
        ``(tokens, stats)``. ``tokens`` stacks every batch along dim 0. When
        batches come back with different widths (e.g. early EOS stopping),
        shorter ones are right-padded with the tokenizer's pad id, or its EOS
        id if it has no pad id. ``stats["avg_seq_len"]`` is the row-weighted
        mean of the per-batch widths before padding.

    Raises:
        ValueError: if ``num_samples``/``batch_size`` is < 1, the tokenizer has
            no BOS id, or padding is needed but no pad/EOS id exists.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    tokenizer = sampler.tokenizer
    bos_token_id = tokenizer.bos_token_id
    if bos_token_id is None:
        raise ValueError("tokenizer has no bos_token_id")

    chunks = []
    done = 0
    t0 = time.perf_counter()
    while done < num_samples:
        bs = min(batch_size, num_samples - done)
        prompt_ids = torch.full(
            (bs, 1), bos_token_id,
            dtype=torch.long, device=sampler.device,
        )
        out = sampler.generate(
            prompt_ids=prompt_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            eos_token_id=tokenizer.eos_token_id,
            stop_on_eos=stop_on_eos,
        )
        chunks.append(out)
        done += bs
        print(f" sampled {done}/{num_samples}")
    # Monotonic timer; the clamp keeps samples_per_sec finite on tiny runs.
    elapsed = max(time.perf_counter() - t0, 1e-9)

    widths = [c.shape[1] for c in chunks]
    rows = [c.shape[0] for c in chunks]
    max_width = max(widths)
    if min(widths) != max_width:
        # Batches that stopped early on EOS are narrower; torch.cat would
        # reject them. Pad with a special id so skip_special_tokens decoding
        # drops it again.
        pad_id = getattr(tokenizer, "pad_token_id", None)
        if pad_id is None:
            pad_id = tokenizer.eos_token_id
        if pad_id is None:
            raise ValueError(
                "AR batches have different lengths and the tokenizer has "
                "neither pad_token_id nor eos_token_id to pad with"
            )
        chunks = [
            c if c.shape[1] == max_width
            else torch.cat(
                [c, c.new_full((c.shape[0], max_width - c.shape[1]), pad_id)],
                dim=1,
            )
            for c in chunks
        ]
    all_tokens = torch.cat(chunks, dim=0)
    avg_seq_len = sum(w * r for w, r in zip(widths, rows)) / sum(rows)
    stats = {
        "total_batches": len(chunks),
        "avg_seq_len": round(avg_seq_len, 2),
        "sampling_time_sec": round(elapsed, 2),
        "samples_per_sec": round(num_samples / elapsed, 2),
    }
    return all_tokens, stats
# ─────────────────────────────────────────────────────────────────────────────
# Reference text loading
# ─────────────────────────────────────────────────────────────────────────────
def load_reference_texts(tokenizer, config, num_samples, pool_size, seed):
    """Collect OWT reference texts and draw ``num_samples`` of them at random.

    The pool is the ``train[-pool_size:]`` slice of OpenWebText, served
    through ``build_owt_dataloader``. For each row, the first
    ``attention_mask.sum()`` tokens are decoded with
    ``skip_special_tokens=True`` (tokenize → BOS/EOS → decode back); blank
    texts are dropped. Collection stops as soon as ``pool_size`` texts have
    been gathered.

    Args:
        tokenizer: Tokenizer handed to the dataloader and used for decoding.
        config: Config dict; reads ``config["model"]["max_seq_len"]`` and the
            optional ``config["data"]`` section (``cache_dir``, ``mode``).
        num_samples: Number of reference texts to return; must be >= 1.
        pool_size: Number of texts to collect before sampling; must be >= 1
            (a value of 0 would make the slice ``train[-0:]``, i.e. the
            entire split).
        seed: Seed for the dataloader and for the selection RNG.

    Returns:
        list[str] with ``min(num_samples, texts collected)`` entries.

    Raises:
        ValueError: if ``num_samples`` or ``pool_size`` is < 1.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")
    if pool_size < 1:
        raise ValueError(f"pool_size must be >= 1, got {pool_size}")

    data_cfg = config.get("data", {})
    seq_len = config["model"]["max_seq_len"]
    cache_dir = data_cfg.get("cache_dir", None)
    # A relative cache dir is re-anchored at the repo root only when that
    # location exists; otherwise the configured value is passed through.
    if cache_dir is not None and not Path(cache_dir).is_absolute():
        candidate = ROOT / cache_dir
        if candidate.exists():
            cache_dir = str(candidate)

    # OWT test split = last pool_size docs
    split = f"train[-{pool_size}:]"
    print(f"Loading OWT reference pool: {split} ...")
    loader = build_owt_dataloader(
        tokenizer,
        split=split,
        seq_len=seq_len,
        batch_size=min(num_samples, 512),
        num_workers=4,
        cache_dir=cache_dir,
        seed=seed,
        mode=data_cfg.get("mode", "subsample"),
        shard_across_ranks=False,
    )

    all_texts = []
    for batch in loader:
        ids = batch["input_ids"]  # [B, seq_len]
        mask = batch["attention_mask"]  # [B, seq_len]
        for row_ids, row_mask in zip(ids, mask):
            # int() so float masks still give a valid slice bound. Treats the
            # leading valid_len tokens as the text, i.e. assumes right padding.
            valid_len = int(row_mask.sum().item())
            text = tokenizer.decode(
                row_ids[:valid_len].tolist(), skip_special_tokens=True,
            )
            if text.strip():
                all_texts.append(text)
            if len(all_texts) >= pool_size:
                break
        # The inner ``break`` only leaves the row loop; stop pulling batches
        # as well instead of iterating the rest of the dataloader.
        if len(all_texts) >= pool_size:
            break
    print(f" collected {len(all_texts)} reference texts from OWT pool")
    if len(all_texts) < num_samples:
        print(f" WARNING: only {len(all_texts)} reference texts available "
              f"(< num_samples={num_samples}); MAUVE will compare unequal "
              f"set sizes")

    rng = random.Random(seed)
    if len(all_texts) <= num_samples:
        selected = all_texts
    else:
        selected = rng.sample(all_texts, num_samples)
    print(f" selected {len(selected)} reference texts")
    return selected
# ─────────────────────────────────────────────────────────────────────────────
# Main
# ─────────────────────────────────────────────────────────────────────────────
def _load_run_config(args, ckpt_path):
    """Return ``(config, config_source)`` from ``--config`` or the checkpoint.

    The checkpoint is only deserialized here when no ``--config`` is given,
    to avoid an extra full load of the weights.

    Raises:
        ValueError: if neither ``--config`` nor an embedded ``"config"`` exists.
    """
    from inference_sad import load_config

    if args.config is not None:
        config = load_config(str(resolve_path(args.config)))
        return config, f"cli:{args.config}"
    ckpt_peek = torch.load(ckpt_path, map_location="cpu")
    if "config" not in ckpt_peek:
        raise ValueError(
            "--config not provided and checkpoint has no embedded 'config'."
        )
    config = copy.deepcopy(ckpt_peek["config"])
    del ckpt_peek  # release the CPU copy of the weights before model loading
    return config, f"checkpoint:{args.checkpoint}"


def _generate_texts(args, config, device, dtype, seq_len):
    """Load the model chosen by ``--model_type``, sample it, and decode to text.

    Returns:
        ``(gen_texts, tokenizer, config, sampling_stats, max_new_tokens)``;
        ``max_new_tokens`` is ``None`` for the block-diffusion samplers.

    Raises:
        ValueError: for a ``model_type`` this function does not handle.
    """
    max_new_tokens = None
    if args.model_type in ("sad", "block_diffusion"):
        if args.model_type == "sad":
            sampler, tokenizer, config = load_sad_model(args, config, device, dtype)
            label = "SAD block-diffusion"
        else:
            sampler, tokenizer, config = load_block_diffusion_model(
                args, config, device, dtype,
            )
            label = "mask-only block-diffusion"
        print(f"\nGenerating {args.num_samples} {label} samples (L={seq_len}) ...")
        tokens, sampling_stats = sample_block(
            sampler, args.num_samples, args.sample_batch_size,
            positions_per_step=args.positions_per_step,
        )
    elif args.model_type == "ar":
        sampler, tokenizer, config = load_ar_model(args, config, device, dtype)
        max_new_tokens = args.max_new_tokens
        if max_new_tokens is None:
            max_new_tokens = seq_len - 1  # BOS + max_new_tokens = seq_len
        print(f"\nGenerating {args.num_samples} AR samples "
              f"(max_new_tokens={max_new_tokens}) ...")
        tokens, sampling_stats = sample_ar(
            sampler, args.num_samples, args.sample_batch_size,
            max_new_tokens=max_new_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            stop_on_eos=not args.no_stop_on_eos,
        )
    else:
        raise ValueError(f"Unsupported --model_type {args.model_type!r}")

    gen_texts = tokenizer.batch_decode(
        tokens.tolist(), skip_special_tokens=True,
    )
    # Drop the generator before MAUVE loads its own featurization model.
    del sampler, tokens
    torch.cuda.empty_cache()
    return gen_texts, tokenizer, config, sampling_stats, max_new_tokens


def _run_mauve(args, device, ref_texts, gen_texts):
    """Score MAUVE(p=reference, q=generated); return ``(result, featurize_model)``."""
    import mauve

    # Prefer the repo-relative local copy when it exists; otherwise pass the
    # raw value through so Hub ids (e.g. "gpt2-large") and CWD-relative paths
    # still reach from_pretrained() intact.
    local_model = resolve_path(args.featurize_model)
    featurize_model = (
        str(local_model) if local_model.exists() else args.featurize_model
    )
    # MAUVE takes a CUDA ordinal (-1 = CPU); honour an explicit "cuda:N".
    if device.type == "cuda":
        device_id = device.index if device.index is not None else 0
    else:
        device_id = -1

    print("\nComputing MAUVE ...")
    print(f" featurize model : {featurize_model}")
    print(f" device_id : {device_id}")
    print(f" max_text_length : {args.mauve_max_text_length}")
    print(f" p (reference) : {len(ref_texts)} texts")
    print(f" q (generated) : {len(gen_texts)} texts")
    result = mauve.compute_mauve(
        p_text=ref_texts,
        q_text=gen_texts,
        featurize_model_name=featurize_model,
        device_id=device_id,
        max_text_length=args.mauve_max_text_length,
        batch_size=args.mauve_batch_size,
        verbose=True,
        seed=args.seed,
    )
    return result, featurize_model


def main():
    """Command-line entry point.

    Parses arguments, resolves the config (``--config`` or the copy embedded
    in the checkpoint), generates ``--num_samples`` texts with the selected
    model, draws as many OWT reference texts, optionally saves both text
    sets, scores them with MAUVE and writes the result JSON.

    Raises:
        SystemExit: if ``--num_samples``, ``--sample_batch_size`` or
            ``--ref_pool_size`` is < 1.
    """
    args = parse_args()
    # Reject sizes that would hang or silently misbehave before any expensive
    # loading: a non-positive batch size never advances the sampling loop, and
    # a pool size of 0 turns the ``train[-N:]`` slice into the whole split.
    for flag in ("num_samples", "sample_batch_size", "ref_pool_size"):
        value = getattr(args, flag)
        if value < 1:
            raise SystemExit(f"--{flag} must be >= 1 (got {value})")

    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)
    device = torch.device(args.device)

    from inference_sad import resolve_dtype
    dtype = resolve_dtype(args.dtype)

    # ── Load config ───────────────────────────────────────────────────────
    ckpt_path = resolve_path(args.checkpoint)
    config, config_source = _load_run_config(args, ckpt_path)
    print(f"Config from {config_source}")
    print(f"Model type: {args.model_type}")

    # ── Load model & generate ─────────────────────────────────────────────
    seq_len = config["model"]["max_seq_len"]
    gen_texts, tokenizer, config, sampling_stats, max_new_tokens = _generate_texts(
        args, config, device, dtype, seq_len,
    )
    print(f"First generated sample: {gen_texts[0][:120]!r}")

    # ── Load reference texts ──────────────────────────────────────────────
    print("\nLoading reference texts from OWT ...")
    ref_texts = load_reference_texts(
        tokenizer, config, args.num_samples,
        args.ref_pool_size, args.seed,
    )

    # Persist the raw texts *before* MAUVE so a featurization failure
    # (e.g. OOM) does not discard the expensive samples.
    if args.save_samples is not None:
        samples_path = resolve_path(args.save_samples)
        samples_path.parent.mkdir(parents=True, exist_ok=True)
        with open(samples_path, "w", encoding="utf-8") as f:
            json.dump({
                "run_id": args.run_id,
                "model_type": args.model_type,
                "checkpoint": str(ckpt_path),
                "seed": args.seed,
                "generated_texts": gen_texts,
                "reference_texts": ref_texts,
            }, f, indent=2)
        print(f"Saved samples → {samples_path}")

    # ── Compute MAUVE ─────────────────────────────────────────────────────
    result, featurize_model = _run_mauve(args, device, ref_texts, gen_texts)
    # float() guarantees a JSON-serialisable value (not a NumPy scalar).
    mauve_score = float(result.mauve)
    print(f"\n{'='*60}")
    print(f" MAUVE score: {mauve_score:.6f}")
    print(f"{'='*60}")

    # ── Save results ──────────────────────────────────────────────────────
    output = {
        "mauve": mauve_score,
        "frontier_integral": float(result.frontier_integral),
        "run_id": args.run_id,
        "model_type": args.model_type,
        "checkpoint": str(ckpt_path),
        "config_source": config_source,
        "seq_len": seq_len,
        "num_gen_samples": len(gen_texts),
        "num_ref_samples": len(ref_texts),
        "sample_batch_size": args.sample_batch_size,
        "featurize_model": featurize_model,
        "mauve_max_text_length": args.mauve_max_text_length,
        "mauve_batch_size": args.mauve_batch_size,
        "ref_pool_size": args.ref_pool_size,
        "seed": args.seed,
        "device": args.device,
        "dtype": args.dtype,
    }
    # Model-type-specific fields
    if args.model_type in {"sad", "block_diffusion"}:
        output.update({
            "positions_per_step": args.positions_per_step,
            "level_lambdas": None if args.model_type == "block_diffusion" else args.level_lambdas,
            "leaf_temperature": args.leaf_temperature,
        })
    elif args.model_type == "ar":
        output.update({
            "temperature": args.temperature,
            "top_k": args.top_k,
            "top_p": args.top_p,
            "max_new_tokens": max_new_tokens,
            "stop_on_eos": not args.no_stop_on_eos,
        })
    # Sampling statistics (avg_steps, timing, etc.)
    output["sampling_stats"] = sampling_stats

    if args.output is not None:
        out_path = resolve_path(args.output)
    else:
        out_path = (resolve_path(args.output_dir)
                    / f"mauve_{args.model_type}_run{args.run_id}.json")
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2)
    print(f"Saved results → {out_path}")


if __name__ == "__main__":
    main()