#!/usr/bin/env python3 """ eval_ar_gen_ppl.py - Generative perplexity of AR baseline samples under an eval LM. Mirrors eval_gen_ppl.py for the autoregressive baseline: 1. Draw N unconditional samples from a trained AR checkpoint. 2. Decode them into text with the GPT-2 tokenizer used for training. 3. Score them under a pretrained eval LM (default: local gpt2-large). 4. Report avg_nll / ppl / acc and optionally save the samples. """ from __future__ import annotations import argparse import copy import json import sys from pathlib import Path ROOT = Path(__file__).resolve().parents[1] # sad/ import numpy as np import torch import torch.nn.functional as F sys.path.insert(0, str(ROOT)) sys.path.insert(0, str(Path(__file__).parent)) # for `inference_ar` from inference_ar import ARSampler, build_model, build_tokenizer, load_config, resolve_dtype, _unwrap def resolve_path(raw: str | None) -> Path | None: if raw is None: return None path = Path(raw) if path.is_absolute(): return path return ROOT / path def parse_args(): p = argparse.ArgumentParser() p.add_argument("--checkpoint", type=str, required=True, help="AR checkpoint.") p.add_argument("--config", type=str, default=None, help="Optional config path. If omitted, uses the config " "stored inside --checkpoint.") p.add_argument("--num_samples", type=int, default=256, help="Total unconditional samples to generate.") p.add_argument("--sample_batch_size", type=int, default=16, help="Batch size for AR sampling.") p.add_argument("--eval_batch_size", type=int, default=8, help="Batch size when feeding samples to the eval LM.") p.add_argument("--eval_model_path", type=str, default="models/gpt2-large") p.add_argument("--eval_tokenizer_path", type=str, default=None, help="Defaults to --eval_model_path when omitted.") p.add_argument("--eval_max_length", type=int, default=1024, help="Truncation length for eval-LM tokenization.") p.add_argument("--max_new_tokens", type=int, default=511, help="Number of new tokens sampled after the BOS prompt.") p.add_argument("--temperature", type=float, default=1.0, help="Sampling temperature. 0 means greedy decoding.") p.add_argument("--top_k", type=int, default=0, help="0 disables top-k sampling.") p.add_argument("--top_p", type=float, default=1.0, help="1.0 disables top-p sampling.") p.add_argument("--no-stop-on-eos", action="store_true", help="Keep sampling until max_new_tokens is reached.") p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") p.add_argument("--dtype", type=str, default="bf16", choices=["bf16", "fp16", "fp32"], help="dtype for AR sampling (eval LM always runs fp32).") p.add_argument("--seed", type=int, default=42) p.add_argument("--output", type=str, default="outputs/ar_gen_ppl_metrics.json") p.add_argument("--save_samples", type=str, default=None, help="Optional path to dump decoded text samples (JSON).") return p.parse_args() @torch.no_grad() def sample_many( sampler: ARSampler, bos_token_id: int, num_samples: int, batch_size: int, max_new_tokens: int, temperature: float, top_k: int, top_p: float, stop_on_eos: bool, ) -> torch.Tensor: chunks = [] done = 0 while done < num_samples: bs = min(batch_size, num_samples - done) prompt_ids = torch.full( (bs, 1), bos_token_id, dtype=torch.long, device=sampler.device, ) out = sampler.generate( prompt_ids=prompt_ids, max_new_tokens=max_new_tokens, temperature=temperature, top_k=top_k, top_p=top_p, eos_token_id=sampler.tokenizer.eos_token_id, stop_on_eos=stop_on_eos, ) chunks.append(out) done += bs print(f" sampled {done}/{num_samples}") return torch.cat(chunks, dim=0) @torch.no_grad() def score_with_eval_lm( texts: list[str], eval_model, eval_tokenizer, device: torch.device, batch_size: int, max_length: int, ) -> dict: total_nll = 0.0 total_tokens = 0 total_acc = 0.0 all_nlls = [] for i in range(0, len(texts), batch_size): batch = texts[i:i + batch_size] enc = eval_tokenizer( batch, padding=True, return_tensors="pt", truncation=True, max_length=max_length, ).to(device) input_ids = enc["input_ids"] attn_mask = enc["attention_mask"] outputs = eval_model( input_ids=input_ids, attention_mask=attn_mask, use_cache=False, return_dict=True, ) logits = outputs.logits[:, :-1] labels = input_ids[:, 1:] loss_mask = attn_mask[:, 1:] nll = F.cross_entropy( logits.transpose(-1, -2), labels, reduction="none", ) valid = loss_mask.bool() nll_valid = nll[valid] total_nll += nll_valid.sum().item() total_tokens += int(valid.sum().item()) all_nlls.extend(nll_valid.detach().cpu().tolist()) preds = logits.argmax(dim=-1) total_acc += ((preds == labels).float() * loss_mask).sum().item() print(f" scored {min(i + batch_size, len(texts))}/{len(texts)}") if total_tokens == 0: raise RuntimeError("No valid tokens scored - all samples were empty?") avg_nll = total_nll / total_tokens return { "avg_nll": avg_nll, "median_nll": float(np.median(all_nlls)), "ppl": float(np.exp(avg_nll)), "acc": total_acc / total_tokens, "tokens": total_tokens, } def main(): args = parse_args() torch.manual_seed(args.seed) device = torch.device(args.device) dtype = resolve_dtype(args.dtype) ckpt_path = resolve_path(args.checkpoint) if ckpt_path is None or not ckpt_path.exists(): raise FileNotFoundError(f"checkpoint not found: {args.checkpoint}") ckpt = torch.load(ckpt_path, map_location=device) if args.config is not None: config = load_config(str(resolve_path(args.config))) config_source = f"cli:{args.config}" else: assert "config" in ckpt, ( "--config was not provided and checkpoint has no embedded 'config' entry." ) config = copy.deepcopy(ckpt["config"]) config_source = f"checkpoint:{args.checkpoint}" print(f"Using config from {config_source}") tokenizer = build_tokenizer(config) bos_token_id = tokenizer.bos_token_id if bos_token_id is None: raise RuntimeError("tokenizer has no bos_token_id") model = build_model(config, device).to(dtype) raw_state = ckpt.get("model", ckpt) _unwrap(model).load_state_dict(raw_state, strict=False) model.eval() print(f"Loaded AR checkpoint: {ckpt_path} (step={ckpt.get('step', '?')})") sampler = ARSampler( model=_unwrap(model), tokenizer=tokenizer, device=device, dtype=dtype, ) total_seq_len = 1 + args.max_new_tokens print( f"Generating {args.num_samples} samples " f"(seq_len={total_seq_len}, temperature={args.temperature})..." ) tokens = sample_many( sampler=sampler, bos_token_id=bos_token_id, num_samples=args.num_samples, batch_size=args.sample_batch_size, max_new_tokens=args.max_new_tokens, temperature=args.temperature, top_k=args.top_k, top_p=args.top_p, stop_on_eos=not args.no_stop_on_eos, ) texts = tokenizer.batch_decode(tokens.tolist(), skip_special_tokens=True) print(f"First sample preview: {texts[0][:120]!r}") del sampler, model if device.type == "cuda": torch.cuda.empty_cache() from transformers import AutoModelForCausalLM, AutoTokenizer eval_model_path = resolve_path(args.eval_model_path) eval_tok_path = resolve_path(args.eval_tokenizer_path) or eval_model_path print(f"Loading eval LM: {eval_model_path}") print(f"Loading eval tokenizer: {eval_tok_path}") eval_tokenizer = AutoTokenizer.from_pretrained( str(eval_tok_path), local_files_only=True, ) if eval_tokenizer.pad_token is None: eval_tokenizer.pad_token = eval_tokenizer.eos_token eval_model = AutoModelForCausalLM.from_pretrained( str(eval_model_path), local_files_only=True, torch_dtype=torch.float32, ).to(device).eval() print(f"Eval LM loaded ({sum(p.numel() for p in eval_model.parameters()):,} params)") print("Scoring samples under eval LM...") metrics = score_with_eval_lm( texts, eval_model, eval_tokenizer, device, args.eval_batch_size, args.eval_max_length, ) metrics.update({ "checkpoint": str(ckpt_path), "eval_model": str(eval_model_path), "eval_tokenizer": str(eval_tok_path), "num_samples": len(texts), "generated_seq_len": int(tokens.shape[1]), "mode": "ar_generation", "temperature": args.temperature, "top_k": args.top_k, "top_p": args.top_p, "prompt_len": 1, "max_new_tokens": args.max_new_tokens, "stop_on_eos": not args.no_stop_on_eos, }) print(json.dumps(metrics, indent=2)) out_path = resolve_path(args.output) assert out_path is not None out_path.parent.mkdir(parents=True, exist_ok=True) with open(out_path, "w") as f: json.dump(metrics, f, indent=2) print(f"Saved metrics -> {out_path}") if args.save_samples: s_path = resolve_path(args.save_samples) assert s_path is not None s_path.parent.mkdir(parents=True, exist_ok=True) with open(s_path, "w") as f: json.dump({"samples": texts}, f, indent=2) print(f"Saved samples -> {s_path}") if __name__ == "__main__": main()