| |
"""
compute_mauve.py — Generate samples from a SAD, mask-only block-diffusion, or AR
checkpoint and compute MAUVE against OpenWebText reference texts.

Pipeline:
    1. Load the checkpoint (SAD block diffusion, mask-only block diffusion, or
       autoregressive AR).
    2. Generate N unconditional samples (default 5000).
    3. Load the OWT reference pool (by default the last 100k documents of the
       train split; see --ref_pool_size) and randomly pick N reference texts.
    4. Decode both sides to plain text.
    5. Compute the MAUVE score using a local GPT-2 Large as the featurize model.
    6. Save the results to JSON.

Usage (SAD block diffusion):
    python eval/compute_mauve.py \
        --checkpoint outputs/sad/latest.pt \
        --model_type sad \
        --num_samples 5000 \
        --sample_batch_size 64 \
        --positions_per_step 1 \
        --level_lambdas 1.0 \
        --leaf_temperature 1.0 \
        --run_id 1

Usage (AR baseline):
    python eval/compute_mauve.py \
        --checkpoint outputs/ar_baseline/latest.pt \
        --model_type ar \
        --num_samples 5000 \
        --sample_batch_size 64 \
        --temperature 1.0 \
        --top_k 0 \
        --top_p 1.0 \
        --run_id 1
"""
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import copy |
| import json |
| import random |
| import sys |
| import time |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
|
|
# Repository root: the parent of the directory that contains this file.
ROOT = Path(__file__).resolve().parents[1]
# Make project modules importable. After this, scripts/ precedes ROOT at the
# front of sys.path, so the inference_* helpers resolve from scripts/.
sys.path[:0] = [str(ROOT / "scripts"), str(ROOT)]
|
|
| from src.data import build_owt_dataloader |
|
|
|
|
def parse_args(argv=None):
    """Parse command-line options for sampling and MAUVE evaluation.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with the parsed options.

    Exits (via ``ArgumentParser.error``) when a count or batch-size option is
    not a positive integer. A zero batch size would otherwise make the
    sampling loops never advance.
    """
    p = argparse.ArgumentParser(
        description="Generate samples (block or AR) and compute MAUVE against OWT.",
    )

    # ---- Model selection ---------------------------------------------------
    p.add_argument("--model_type", type=str, default="sad",
                   choices=["sad", "block_diffusion", "ar"],
                   help="Model type: 'sad' for SAD block-diffusion, "
                        "'block_diffusion' for mask-only block diffusion, "
                        "'ar' for autoregressive baseline.")

    # ---- Checkpoint / config -----------------------------------------------
    p.add_argument("--checkpoint", type=str, required=True,
                   help="Path to checkpoint (.pt).")
    p.add_argument("--config", type=str, default=None,
                   help="Optional config path. If omitted, uses config "
                        "stored in the checkpoint.")

    # ---- Block-diffusion sampling (sad / block_diffusion) --------------------
    p.add_argument("--positions_per_step", type=int, default=1,
                   help="(sad, block_diffusion) Non-leaf positions advanced "
                        "per denoising round.")
    p.add_argument("--level_lambdas", type=str, default=None,
                   help="(sad) Comma-separated floats in [0,1], one per ancestor "
                        "level (e.g. '1.0,0.8'). Default: all 1.0.")
    p.add_argument("--leaf_temperature", type=float, default=1.0,
                   help="(sad, block_diffusion) Temperature on leaf logits "
                        "before softmax.")

    # ---- AR sampling ---------------------------------------------------------
    p.add_argument("--temperature", type=float, default=1.0,
                   help="(ar) Sampling temperature. 0 means greedy.")
    p.add_argument("--top_k", type=int, default=0,
                   help="(ar) Top-k sampling. 0 disables.")
    p.add_argument("--top_p", type=float, default=1.0,
                   help="(ar) Top-p (nucleus) sampling. 1.0 disables.")
    p.add_argument("--max_new_tokens", type=int, default=None,
                   help="(ar) Number of tokens to generate after BOS. "
                        "Default: max_seq_len - 1 from config.")
    p.add_argument("--no_stop_on_eos", action="store_true",
                   help="(ar) Keep sampling until max_new_tokens is reached.")

    # ---- Sample counts -------------------------------------------------------
    p.add_argument("--num_samples", type=int, default=5000,
                   help="Number of samples to generate AND reference texts "
                        "to draw from OWT. Default 5000.")
    p.add_argument("--sample_batch_size", type=int, default=64,
                   help="Batch size for sampling.")

    # ---- Reference pool ------------------------------------------------------
    p.add_argument("--ref_pool_size", type=int, default=100_000,
                   help="Size of the OWT test pool to draw from. "
                        "Default 100000 (last 100k docs).")

    # ---- MAUVE featurization -------------------------------------------------
    p.add_argument("--featurize_model", type=str, default="models/gpt2-large",
                   help="Model name or local path for MAUVE featurization. "
                        "Passed to AutoModel.from_pretrained(). "
                        "Default: models/gpt2-large (local).")
    p.add_argument("--mauve_max_text_length", type=int, default=512,
                   help="Max tokens for MAUVE featurization. Default 512.")
    p.add_argument("--mauve_batch_size", type=int, default=8,
                   help="Batch size for MAUVE featurization. Default 8.")

    # ---- Runtime -------------------------------------------------------------
    p.add_argument("--device", type=str,
                   default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument("--dtype", type=str, default="bf16",
                   choices=["bf16", "fp16", "fp32"])
    p.add_argument("--seed", type=int, default=42)

    # ---- Output --------------------------------------------------------------
    p.add_argument("--output_dir", type=str, default="eval/mauve_results",
                   help="Directory to save MAUVE result JSONs when --output "
                        "is not provided.")
    p.add_argument("--output", type=str, default=None,
                   help="Optional full path to the result JSON. If set, "
                        "overrides --output_dir and the default filename.")
    p.add_argument("--save_samples", type=str, default=None,
                   help="Optional JSON path to save generated and reference "
                        "texts for later inspection or MAUVE recomputation.")
    p.add_argument("--run_id", type=int, default=1,
                   help="Run identifier for distinguishing repeated runs. "
                        "Output file: mauve_{model_type}_run{run_id}.json")

    args = p.parse_args(argv)

    # Reject non-positive sizes up front: the sampling loops advance by
    # min(batch_size, remaining), so a zero batch size never terminates.
    for name in ("num_samples", "sample_batch_size", "ref_pool_size",
                 "mauve_batch_size", "mauve_max_text_length"):
        value = getattr(args, name)
        if value < 1:
            p.error(f"--{name} must be >= 1 (got {value})")
    return args
|
|
|
|
def resolve_path(raw: str) -> Path:
    """Turn ``raw`` into a Path, anchoring relative paths at the repo root.

    Absolute paths are returned as-is; anything else is joined onto ROOT.
    """
    candidate = Path(raw)
    return candidate if candidate.is_absolute() else ROOT / candidate
|
|
|
|
| |
| |
| |
|
|
def load_sad_model(args, config, device, dtype):
    """Build a SAD block-diffusion sampler from ``args.checkpoint``.

    Loads the model weights (the checkpoint's ``"model"`` entry, or the whole
    checkpoint if that key is absent) with ``strict=False``, loads the
    mandatory ``"ancestor_table"`` entry, and wraps both in
    ``BlockDiffusionSampler``.

    Args:
        args: Parsed CLI namespace; reads ``checkpoint``, ``level_lambdas``
            (comma-separated floats) and ``leaf_temperature``.
        config: Config dict passed to the builders;
            ``config["model"]["hidden_size"]`` sizes the ancestor table.
        device: Torch device for the model and ancestor table.
        dtype: Torch dtype both modules are cast to.

    Returns:
        ``(sampler, tokenizer, config)``; ``config`` is returned unchanged.

    Raises:
        KeyError: If the checkpoint has no ``"ancestor_table"`` entry.
    """
    from inference_sad import (
        BlockDiffusionSampler,
        build_ancestor_table,
        build_model,
        build_tokenizer,
        _unwrap,
    )

    ckpt_path = resolve_path(args.checkpoint)
    # Deserialize on CPU. load_state_dict copies each tensor into the
    # existing parameter wherever it lives, so staging the whole checkpoint
    # dict on the accelerator would only raise peak memory.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    # Fail fast, before building the model. An explicit raise also survives -O.
    if "ancestor_table" not in ckpt:
        raise KeyError(
            f"Checkpoint {ckpt_path} has no 'ancestor_table' entry, "
            "which SAD sampling requires."
        )

    tokenizer = build_tokenizer(config)
    model = build_model(config, device).to(dtype)
    raw_state = ckpt.get("model", ckpt)
    incompatible = _unwrap(model).load_state_dict(raw_state, strict=False)
    # strict=False tolerates key mismatches; surface them so a partially
    # initialised model does not go unnoticed in the MAUVE numbers.
    missing = list(getattr(incompatible, "missing_keys", []))
    unexpected = list(getattr(incompatible, "unexpected_keys", []))
    if missing or unexpected:
        print(f"WARNING: state_dict mismatch ({len(missing)} missing, "
              f"{len(unexpected)} unexpected); first missing={missing[:5]}, "
              f"first unexpected={unexpected[:5]}")
    model.eval()
    print(f"Loaded SAD checkpoint: {ckpt_path} (step={ckpt.get('step', '?')})")

    ancestor_table = build_ancestor_table(
        config, device, embed_dim=config["model"]["hidden_size"],
    )
    ancestor_table.load_state_dict(ckpt["ancestor_table"])
    ancestor_table.to(device=device, dtype=dtype).eval()

    level_lambdas = None
    if args.level_lambdas:
        # Ignore empty fields so a trailing comma ("1.0,0.8,") is harmless.
        level_lambdas = [
            float(x) for x in args.level_lambdas.split(",") if x.strip()
        ]

    sampler = BlockDiffusionSampler(
        model=_unwrap(model), ancestor_table=ancestor_table,
        tokenizer=tokenizer, device=device, dtype=dtype,
        level_lambdas=level_lambdas,
        leaf_temperature=args.leaf_temperature,
    )
    print(f"level_lambdas = {sampler.level_lambdas[1:]}")
    print(f"leaf_temperature = {sampler.leaf_temperature}")

    return sampler, tokenizer, config
|
|
|
|
def load_block_diffusion_model(args, config, device, dtype):
    """Build a mask-only block-diffusion sampler from ``args.checkpoint``.

    Loads the model weights (the checkpoint's ``"model"`` entry, or the whole
    checkpoint if that key is absent) with ``strict=False`` and wraps the
    unwrapped model in ``BlockMaskDiffusionSampler``.

    Args:
        args: Parsed CLI namespace; reads ``checkpoint`` and ``leaf_temperature``.
        config: Config dict passed to the model/tokenizer builders.
        device: Torch device for the model.
        dtype: Torch dtype the model is cast to.

    Returns:
        ``(sampler, tokenizer, config)``; ``config`` is returned unchanged.
    """
    from inference_block_diffusion import (
        BlockMaskDiffusionSampler,
        build_model,
        build_tokenizer,
        _unwrap,
    )

    ckpt_path = resolve_path(args.checkpoint)
    # Deserialize on CPU; load_state_dict copies tensors into the existing
    # parameters, so there is no need to stage the checkpoint on the device.
    ckpt = torch.load(ckpt_path, map_location="cpu")

    tokenizer = build_tokenizer(config)
    model = build_model(config, device).to(dtype)
    raw_state = ckpt.get("model", ckpt)
    incompatible = _unwrap(model).load_state_dict(raw_state, strict=False)
    # strict=False tolerates key mismatches; surface them instead of
    # silently evaluating a partially initialised model.
    missing = list(getattr(incompatible, "missing_keys", []))
    unexpected = list(getattr(incompatible, "unexpected_keys", []))
    if missing or unexpected:
        print(f"WARNING: state_dict mismatch ({len(missing)} missing, "
              f"{len(unexpected)} unexpected); first missing={missing[:5]}, "
              f"first unexpected={unexpected[:5]}")
    model.eval()
    print(f"Loaded block-diffusion checkpoint: {ckpt_path} (step={ckpt.get('step', '?')})")

    sampler = BlockMaskDiffusionSampler(
        model=_unwrap(model),
        tokenizer=tokenizer,
        device=device,
        dtype=dtype,
        leaf_temperature=args.leaf_temperature,
    )
    print(f"leaf_temperature = {sampler.leaf_temperature}")

    return sampler, tokenizer, config
|
|
|
|
@torch.no_grad()
def sample_block(sampler, num_samples, batch_size, positions_per_step=1):
    """Draw ``num_samples`` sequences from a block-diffusion sampler in batches.

    Calls ``sampler.generate(batch_size=..., positions_per_step=...)`` until
    ``num_samples`` rows have been produced. The last batch is shrunk to the
    remainder. Each call is expected to return a dict with ``"tokens"`` (a
    tensor whose rows are samples) and ``"num_steps"``.

    Args:
        sampler: Object exposing ``generate`` as described above.
        num_samples: Total number of sequences to produce (>= 1).
        batch_size: Maximum sequences per ``generate`` call (>= 1).
        positions_per_step: Forwarded unchanged to ``sampler.generate``.

    Returns:
        ``(tokens, stats)``. ``tokens`` is all batches concatenated along
        dim 0. ``stats`` holds per-batch denoising step counts and timing.

    Raises:
        ValueError: If ``num_samples`` or ``batch_size`` is smaller than 1.
            A zero batch size would otherwise loop forever.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    chunks = []
    steps_list = []
    done = 0
    t0 = time.perf_counter()
    while done < num_samples:
        bs = min(batch_size, num_samples - done)
        out = sampler.generate(
            batch_size=bs, positions_per_step=positions_per_step,
        )
        chunks.append(out["tokens"])
        steps_list.append(out["num_steps"])
        done += bs
        print(f"  sampled {done}/{num_samples} "
              f"(steps this batch: {out['num_steps']})")
    # Clamp so a very fast (or coarse-clock) run cannot divide by zero.
    elapsed = max(time.perf_counter() - t0, 1e-9)
    stats = {
        "total_batches": len(steps_list),
        "steps_per_batch": steps_list,
        "avg_steps": sum(steps_list) / len(steps_list),
        "min_steps": min(steps_list),
        "max_steps": max(steps_list),
        "sampling_time_sec": round(elapsed, 2),
        "samples_per_sec": round(num_samples / elapsed, 2),
    }
    return torch.cat(chunks, dim=0), stats
|
|
|
|
| |
| |
| |
|
|
def load_ar_model(args, config, device, dtype):
    """Build an autoregressive sampler from ``args.checkpoint``.

    Loads the model weights (the checkpoint's ``"model"`` entry, or the whole
    checkpoint if that key is absent) with ``strict=False`` and wraps the
    unwrapped model in ``ARSampler``.

    Args:
        args: Parsed CLI namespace; reads ``checkpoint``.
        config: Config dict passed to the model/tokenizer builders.
        device: Torch device for the model.
        dtype: Torch dtype the model is cast to.

    Returns:
        ``(sampler, tokenizer, config)``; ``config`` is returned unchanged.
    """
    from inference_ar import (
        ARSampler,
        build_model,
        build_tokenizer,
        _unwrap,
    )

    ckpt_path = resolve_path(args.checkpoint)
    # Deserialize on CPU; load_state_dict copies tensors into the existing
    # parameters, so there is no need to stage the checkpoint on the device.
    ckpt = torch.load(ckpt_path, map_location="cpu")

    tokenizer = build_tokenizer(config)
    model = build_model(config, device).to(dtype)
    raw_state = ckpt.get("model", ckpt)
    incompatible = _unwrap(model).load_state_dict(raw_state, strict=False)
    # strict=False tolerates key mismatches; surface them instead of
    # silently evaluating a partially initialised model.
    missing = list(getattr(incompatible, "missing_keys", []))
    unexpected = list(getattr(incompatible, "unexpected_keys", []))
    if missing or unexpected:
        print(f"WARNING: state_dict mismatch ({len(missing)} missing, "
              f"{len(unexpected)} unexpected); first missing={missing[:5]}, "
              f"first unexpected={unexpected[:5]}")
    model.eval()
    print(f"Loaded AR checkpoint: {ckpt_path} (step={ckpt.get('step', '?')})")

    sampler = ARSampler(
        model=_unwrap(model), tokenizer=tokenizer,
        device=device, dtype=dtype,
    )

    return sampler, tokenizer, config
|
|
|
|
@torch.no_grad()
def sample_ar(sampler, num_samples, batch_size, max_new_tokens,
              temperature, top_k, top_p, stop_on_eos):
    """Draw ``num_samples`` unconditional AR samples, each prompted by one BOS token.

    Each batch is generated by ``sampler.generate`` from a ``(bs, 1)`` prompt
    filled with ``sampler.tokenizer.bos_token_id``. Batches may come back with
    different widths (e.g. when early stopping on EOS). Narrower batches are
    right-padded before concatenation. The pad token is the tokenizer's EOS id,
    falling back to its pad id and then to its BOS id.

    Args:
        sampler: Object with ``tokenizer``, ``device`` and ``generate(...)``.
        num_samples: Total sequences to produce (>= 1).
        batch_size: Maximum sequences per ``generate`` call (>= 1).
        max_new_tokens, temperature, top_k, top_p, stop_on_eos: Forwarded to
            ``sampler.generate``.

    Returns:
        ``(tokens, stats)``. ``tokens`` has shape ``(num_samples, W)``, where
        ``W`` is the widest batch. ``stats["avg_seq_len"]`` is the mean number
        of tokens before the first EOS that follows the BOS prompt (rows
        without such an EOS count the full width). ``stats["max_seq_len"]``
        is ``W``.

    Raises:
        ValueError: If the tokenizer has no BOS id, or ``num_samples`` or
            ``batch_size`` is smaller than 1.
    """
    tokenizer = sampler.tokenizer
    bos_token_id = tokenizer.bos_token_id
    if bos_token_id is None:
        raise ValueError("tokenizer has no bos_token_id")
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    eos_token_id = tokenizer.eos_token_id
    pad_token_id = eos_token_id
    if pad_token_id is None:
        pad_token_id = getattr(tokenizer, "pad_token_id", None)
    if pad_token_id is None:
        pad_token_id = bos_token_id

    chunks = []
    done = 0
    t0 = time.perf_counter()
    while done < num_samples:
        bs = min(batch_size, num_samples - done)
        prompt_ids = torch.full(
            (bs, 1), bos_token_id,
            dtype=torch.long, device=sampler.device,
        )
        out = sampler.generate(
            prompt_ids=prompt_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            eos_token_id=eos_token_id,
            stop_on_eos=stop_on_eos,
        )
        chunks.append(out)
        done += bs
        print(f"  sampled {done}/{num_samples}")
    # Clamp so a very fast (or coarse-clock) run cannot divide by zero.
    elapsed = max(time.perf_counter() - t0, 1e-9)

    # Right-pad narrower batches so torch.cat along dim 0 cannot fail.
    width = max(chunk.shape[1] for chunk in chunks)
    padded = []
    for chunk in chunks:
        extra = width - chunk.shape[1]
        if extra:
            fill = torch.full(
                (chunk.shape[0], extra), pad_token_id,
                dtype=chunk.dtype, device=chunk.device,
            )
            chunk = torch.cat([chunk, fill], dim=1)
        padded.append(chunk)
    all_tokens = torch.cat(padded, dim=0)

    # Per-row length up to the first EOS after position 0. The search skips
    # the BOS prompt because some tokenizers share one id for BOS and EOS.
    if eos_token_id is not None and width > 1:
        is_eos = all_tokens[:, 1:] == eos_token_id
        # argmax returns the index of the first maximal value, i.e. the first EOS.
        first_eos = is_eos.float().argmax(dim=1) + 1
        lengths = torch.where(
            is_eos.any(dim=1), first_eos, torch.full_like(first_eos, width),
        )
        avg_len = lengths.float().mean().item()
    else:
        avg_len = float(width)

    stats = {
        "total_batches": len(chunks),
        "avg_seq_len": round(avg_len, 2),
        "max_seq_len": width,
        "sampling_time_sec": round(elapsed, 2),
        "samples_per_sec": round(num_samples / elapsed, 2),
    }
    return all_tokens, stats
|
|
|
|
| |
| |
| |
|
|
def load_reference_texts(tokenizer, config, num_samples, pool_size, seed):
    """Draw ``num_samples`` plain-text OpenWebText references for MAUVE.

    The candidate pool is the last ``pool_size`` documents of the OWT
    ``train`` split (``train[-pool_size:]``). It is read through
    ``build_owt_dataloader`` with the model's ``max_seq_len`` and the config's
    data ``mode``. Each loader row is decoded back to text with special tokens
    skipped, and whitespace-only texts are discarded. At most ``pool_size``
    texts are collected. ``num_samples`` of them are then drawn without
    replacement using ``random.Random(seed)``.

    Args:
        tokenizer: Tokenizer used by the loader and for decoding.
        config: Config dict; reads ``model.max_seq_len`` and the optional
            ``data.cache_dir`` / ``data.mode``.
        num_samples: Number of reference texts wanted.
        pool_size: Number of trailing train-split documents to draw from.
        seed: Seed for the loader and for the final random selection.

    Returns:
        list[str]: The selected reference texts. If the pool yields fewer
        usable texts than ``num_samples``, all of them are returned and a
        warning is printed.
    """
    data_cfg = config.get("data", {})
    seq_len = config["model"]["max_seq_len"]

    # A relative cache_dir is interpreted against the repo root when that
    # directory exists; otherwise it is passed through unchanged.
    cache_dir = data_cfg.get("cache_dir", None)
    if cache_dir is not None and not Path(cache_dir).is_absolute():
        candidate = ROOT / cache_dir
        if candidate.exists():
            cache_dir = str(candidate)

    split = f"train[-{pool_size}:]"
    print(f"Loading OWT reference pool: {split} ...")
    loader = build_owt_dataloader(
        tokenizer,
        split=split,
        seq_len=seq_len,
        batch_size=min(num_samples, 512),
        num_workers=4,
        cache_dir=cache_dir,
        seed=seed,
        mode=data_cfg.get("mode", "subsample"),
        shard_across_ranks=False,
    )

    all_texts = []
    for batch in loader:
        ids = batch["input_ids"]
        mask = batch["attention_mask"]
        for row_ids, row_mask in zip(ids, mask):
            # NOTE(review): keeping the first mask.sum() tokens assumes right
            # padding (valid tokens first) - confirm against the loader.
            valid_len = int(row_mask.sum().item())
            text = tokenizer.decode(
                row_ids[:valid_len].tolist(), skip_special_tokens=True,
            )
            if text.strip():
                all_texts.append(text)
            if len(all_texts) >= pool_size:
                break
        # Leave the outer loop too. Breaking only the inner loop would
        # append one more text from every remaining batch.
        if len(all_texts) >= pool_size:
            break

    print(f"  collected {len(all_texts)} reference texts from OWT pool")

    rng = random.Random(seed)
    if len(all_texts) <= num_samples:
        selected = all_texts
        if len(all_texts) < num_samples:
            print(f"  WARNING: only {len(all_texts)} usable reference texts "
                  f"for {num_samples} requested; MAUVE will use fewer "
                  f"references than generated samples")
    else:
        selected = rng.sample(all_texts, num_samples)

    print(f"  selected {len(selected)} reference texts")
    return selected
|
|
|
|
| |
| |
| |
|
|
def main():
    """Sample from the chosen model and score the samples with MAUVE against OWT.

    Steps:
      1. Parse CLI args and seed torch / random / numpy.
      2. Resolve the config from --config, or else from the checkpoint's
         embedded ``"config"``.
      3. Generate ``--num_samples`` texts with the SAD, mask-only
         block-diffusion, or AR sampler.
      4. Draw the same number of OWT reference texts.
      5. Compute MAUVE with ``mauve.compute_mauve``.
      6. Write the result JSON, and optionally the raw texts.

    Raises:
        ValueError: If no config can be found, or ``--model_type`` is unknown.
    """
    args = parse_args()
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    device = torch.device(args.device)

    from inference_sad import load_config, resolve_dtype
    dtype = resolve_dtype(args.dtype)

    # ---- Config -----------------------------------------------------------
    ckpt_path = resolve_path(args.checkpoint)
    if args.config is not None:
        config = load_config(str(resolve_path(args.config)))
        config_source = f"cli:{args.config}"
    else:
        # Deserialize the checkpoint here only when its embedded config is
        # needed; the model loaders read it again for the weights.
        ckpt_peek = torch.load(ckpt_path, map_location="cpu")
        if "config" not in ckpt_peek:
            raise ValueError(
                "--config not provided and checkpoint has no embedded 'config'."
            )
        config = copy.deepcopy(ckpt_peek["config"])
        del ckpt_peek
        config_source = f"checkpoint:{args.checkpoint}"
    print(f"Config from {config_source}")
    print(f"Model type: {args.model_type}")

    seq_len = config["model"]["max_seq_len"]

    # ---- Generation -------------------------------------------------------
    sampling_stats = {}
    max_new_tokens = None  # only set (and recorded) for --model_type ar

    if args.model_type in ("sad", "block_diffusion"):
        if args.model_type == "sad":
            sampler, tokenizer, config = load_sad_model(args, config, device, dtype)
            label = "SAD block-diffusion"
        else:
            sampler, tokenizer, config = load_block_diffusion_model(
                args, config, device, dtype,
            )
            label = "block-diffusion"
        print(f"\nGenerating {args.num_samples} {label} samples (L={seq_len}) ...")
        tokens, sampling_stats = sample_block(
            sampler, args.num_samples, args.sample_batch_size,
            positions_per_step=args.positions_per_step,
        )
    elif args.model_type == "ar":
        sampler, tokenizer, config = load_ar_model(args, config, device, dtype)
        max_new_tokens = args.max_new_tokens
        if max_new_tokens is None:
            max_new_tokens = seq_len - 1
        print(f"\nGenerating {args.num_samples} AR samples "
              f"(max_new_tokens={max_new_tokens}) ...")
        tokens, sampling_stats = sample_ar(
            sampler, args.num_samples, args.sample_batch_size,
            max_new_tokens=max_new_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            stop_on_eos=not args.no_stop_on_eos,
        )
    else:
        raise ValueError(f"Unsupported --model_type: {args.model_type!r}")

    gen_texts = tokenizer.batch_decode(tokens.tolist(), skip_special_tokens=True)

    # Release the generator (and its token tensor) before MAUVE loads its
    # featurization model.
    del sampler, tokens
    if device.type == "cuda":
        torch.cuda.empty_cache()

    print(f"First generated sample: {gen_texts[0][:120]!r}")

    # ---- Reference texts --------------------------------------------------
    print("\nLoading reference texts from OWT ...")
    ref_texts = load_reference_texts(
        tokenizer, config, args.num_samples,
        args.ref_pool_size, args.seed,
    )

    # ---- MAUVE ------------------------------------------------------------
    import mauve

    # --featurize_model may be a repo-relative directory or a hub model name:
    # use the resolved path only if it exists on disk.
    local_featurizer = resolve_path(args.featurize_model)
    if local_featurizer.exists():
        featurize_model = str(local_featurizer)
    else:
        featurize_model = args.featurize_model
    # mauve takes a CUDA ordinal (-1 = CPU); honour an explicit "cuda:N".
    if device.type == "cuda":
        device_id = device.index if device.index is not None else 0
    else:
        device_id = -1

    print("\nComputing MAUVE ...")
    print(f"  featurize model : {featurize_model}")
    print(f"  device_id       : {device_id}")
    print(f"  max_text_length : {args.mauve_max_text_length}")
    print(f"  p (reference)   : {len(ref_texts)} texts")
    print(f"  q (generated)   : {len(gen_texts)} texts")

    result = mauve.compute_mauve(
        p_text=ref_texts,
        q_text=gen_texts,
        featurize_model_name=featurize_model,
        device_id=device_id,
        max_text_length=args.mauve_max_text_length,
        batch_size=args.mauve_batch_size,
        verbose=True,
        seed=args.seed,
    )

    # Coerce to a plain float so json.dump accepts any numpy scalar type.
    mauve_score = float(result.mauve)
    print(f"\n{'=' * 60}")
    print(f"  MAUVE score: {mauve_score:.6f}")
    print(f"{'=' * 60}")

    # ---- Results ------------------------------------------------------------
    output = {
        "mauve": mauve_score,
        "frontier_integral": float(result.frontier_integral),
        "run_id": args.run_id,
        "model_type": args.model_type,
        "checkpoint": str(ckpt_path),
        "config_source": config_source,
        "seq_len": seq_len,
        "num_gen_samples": len(gen_texts),
        "num_ref_samples": len(ref_texts),
        "sample_batch_size": args.sample_batch_size,
        "featurize_model": featurize_model,
        "mauve_max_text_length": args.mauve_max_text_length,
        "mauve_batch_size": args.mauve_batch_size,
        "ref_pool_size": args.ref_pool_size,
        "seed": args.seed,
        "device": args.device,
        "dtype": args.dtype,
    }

    if args.model_type in {"sad", "block_diffusion"}:
        output.update({
            "positions_per_step": args.positions_per_step,
            "level_lambdas": None if args.model_type == "block_diffusion" else args.level_lambdas,
            "leaf_temperature": args.leaf_temperature,
        })
    elif args.model_type == "ar":
        output.update({
            "temperature": args.temperature,
            "top_k": args.top_k,
            "top_p": args.top_p,
            "max_new_tokens": max_new_tokens,
            "stop_on_eos": not args.no_stop_on_eos,
        })

    output["sampling_stats"] = sampling_stats

    if args.output is not None:
        out_path = resolve_path(args.output)
        out_path.parent.mkdir(parents=True, exist_ok=True)
    else:
        out_dir = resolve_path(args.output_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        out_path = out_dir / f"mauve_{args.model_type}_run{args.run_id}.json"
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2)
    print(f"Saved results → {out_path}")

    if args.save_samples is not None:
        samples_path = resolve_path(args.save_samples)
        samples_path.parent.mkdir(parents=True, exist_ok=True)
        with open(samples_path, "w", encoding="utf-8") as f:
            json.dump({
                "run_id": args.run_id,
                "model_type": args.model_type,
                "checkpoint": str(ckpt_path),
                "seed": args.seed,
                "generated_texts": gen_texts,
                "reference_texts": ref_texts,
            }, f, indent=2)
        print(f"Saved samples → {samples_path}")


if __name__ == "__main__":
    main()
|
|