| |
| """ |
| eval_ar_gen_ppl.py - Generative perplexity of AR baseline samples under an eval LM. |
| |
| Mirrors eval_gen_ppl.py for the autoregressive baseline: |
| |
| 1. Draw N unconditional samples from a trained AR checkpoint. |
| 2. Decode them into text with the GPT-2 tokenizer used for training. |
| 3. Score them under a pretrained eval LM (default: local gpt2-large). |
| 4. Report avg_nll / ppl / acc and optionally save the samples. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import copy |
| import json |
| import sys |
| from pathlib import Path |
|
|
| ROOT = Path(__file__).resolve().parents[1] |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
|
|
| sys.path.insert(0, str(ROOT)) |
| sys.path.insert(0, str(Path(__file__).parent)) |
|
|
| from inference_ar import ARSampler, build_model, build_tokenizer, load_config, resolve_dtype, _unwrap |
|
|
|
|
def resolve_path(raw: str | None) -> Path | None:
    """Convert a CLI path string into a ``Path`` anchored at the repo root.

    ``None`` passes through unchanged. Absolute paths are returned as-is;
    relative paths are interpreted relative to ``ROOT`` rather than the
    current working directory.
    """
    if raw is None:
        return None
    candidate = Path(raw)
    return candidate if candidate.is_absolute() else ROOT / candidate
|
|
|
|
def _positive_int(raw: str) -> int:
    """argparse ``type=`` converter that accepts only integers >= 1."""
    value = int(raw)
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be an integer >= 1, got {raw!r}")
    return value


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options for the AR generative-perplexity evaluation.

    Args:
        argv: Arguments to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as ``argparse`` does by default.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--checkpoint", type=str, required=True,
                   help="AR checkpoint.")
    p.add_argument("--config", type=str, default=None,
                   help="Optional config path. If omitted, uses the config "
                        "stored inside --checkpoint.")
    # Counts and sizes must be >= 1: a zero batch size would never advance
    # the sampling loop, and zero samples/tokens leave nothing to score.
    p.add_argument("--num_samples", type=_positive_int, default=256,
                   help="Total unconditional samples to generate.")
    p.add_argument("--sample_batch_size", type=_positive_int, default=16,
                   help="Batch size for AR sampling.")
    p.add_argument("--eval_batch_size", type=_positive_int, default=8,
                   help="Batch size when feeding samples to the eval LM.")
    p.add_argument("--eval_model_path", type=str, default="models/gpt2-large")
    p.add_argument("--eval_tokenizer_path", type=str, default=None,
                   help="Defaults to --eval_model_path when omitted.")
    p.add_argument("--eval_max_length", type=_positive_int, default=1024,
                   help="Truncation length for eval-LM tokenization.")
    p.add_argument("--max_new_tokens", type=_positive_int, default=511,
                   help="Number of new tokens sampled after the BOS prompt.")
    p.add_argument("--temperature", type=float, default=1.0,
                   help="Sampling temperature. 0 means greedy decoding.")
    p.add_argument("--top_k", type=int, default=0,
                   help="0 disables top-k sampling.")
    p.add_argument("--top_p", type=float, default=1.0,
                   help="1.0 disables top-p sampling.")
    # The underscore alias matches every other flag in this parser; the
    # original hyphenated spelling is kept for existing callers.
    p.add_argument("--no-stop-on-eos", "--no_stop_on_eos", dest="no_stop_on_eos",
                   action="store_true",
                   help="Keep sampling until max_new_tokens is reached.")
    p.add_argument("--device", type=str,
                   default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument("--dtype", type=str, default="bf16",
                   choices=["bf16", "fp16", "fp32"],
                   help="dtype for AR sampling (eval LM always runs fp32).")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--output", type=str, default="outputs/ar_gen_ppl_metrics.json")
    p.add_argument("--save_samples", type=str, default=None,
                   help="Optional path to dump decoded text samples (JSON).")
    return p.parse_args(argv)
|
|
|
|
@torch.no_grad()
def sample_many(
    sampler: ARSampler,
    bos_token_id: int,
    num_samples: int,
    batch_size: int,
    max_new_tokens: int,
    temperature: float,
    top_k: int,
    top_p: float,
    stop_on_eos: bool,
) -> torch.Tensor:
    """Draw ``num_samples`` unconditional sequences from ``sampler`` in batches.

    Every sequence starts from a one-token prompt of ``bos_token_id`` and is
    extended by ``sampler.generate``.

    Args:
        sampler: Sampler exposing ``device``, ``tokenizer.eos_token_id`` and
            ``generate(...)``.
        bos_token_id: Token id used as the single prompt token of each sample.
        num_samples: Total number of sequences to draw (must be >= 1).
        batch_size: Maximum sequences per ``generate`` call (must be >= 1).
        max_new_tokens, temperature, top_k, top_p, stop_on_eos: Forwarded
            unchanged to ``sampler.generate``.

    Returns:
        Long tensor of shape ``(num_samples, L)``, where ``L`` is the longest
        sequence returned by any batch. Shorter batches are right-padded with
        the tokenizer's EOS id (``bos_token_id`` if there is none).

    Raises:
        ValueError: if ``num_samples`` or ``batch_size`` is smaller than 1.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")
    if batch_size < 1:
        # With bs == 0 the loop below would never make progress.
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    eos_token_id = sampler.tokenizer.eos_token_id
    fill_id = eos_token_id if eos_token_id is not None else bos_token_id

    chunks: list[torch.Tensor] = []
    done = 0
    while done < num_samples:
        bs = min(batch_size, num_samples - done)
        prompt_ids = torch.full(
            (bs, 1),
            bos_token_id,
            dtype=torch.long,
            device=sampler.device,
        )
        out = sampler.generate(
            prompt_ids=prompt_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            eos_token_id=eos_token_id,
            stop_on_eos=stop_on_eos,
        )
        chunks.append(out)
        done += bs
        print(f" sampled {done}/{num_samples}")

    # NOTE(review): with stop_on_eos the sampler may end a batch early, so
    # batches can come back with different lengths and a plain torch.cat
    # along dim 0 would fail. Right-pad every chunk to the longest one; the
    # filler is a special token, which skip_special_tokens decoding drops.
    max_len = max(chunk.shape[1] for chunk in chunks)
    padded = []
    for chunk in chunks:
        gap = max_len - chunk.shape[1]
        if gap > 0:
            filler = chunk.new_full((chunk.shape[0], gap), fill_id)
            chunk = torch.cat([chunk, filler], dim=1)
        padded.append(chunk)
    return torch.cat(padded, dim=0)
|
|
|
|
@torch.no_grad()
def score_with_eval_lm(
    texts: list[str],
    eval_model,
    eval_tokenizer,
    device: torch.device,
    batch_size: int,
    max_length: int,
) -> dict:
    """Score ``texts`` under a pretrained causal LM and aggregate token metrics.

    Each batch is tokenized by ``eval_tokenizer`` (padded, truncated to
    ``max_length``) and run through ``eval_model``. Every real token that
    follows another real token is scored by next-token cross-entropy, so the
    first token of each text (which has no context) is never scored.

    Args:
        texts: Decoded sample strings.
        eval_model: Callable returning an object with ``.logits`` of shape
            ``(batch, seq, vocab)`` (Hugging Face causal-LM style).
        eval_tokenizer: Callable tokenizer returning an encoding with
            ``input_ids`` / ``attention_mask`` and a ``.to(device)`` method.
        device: Device the tokenized batch is moved to.
        batch_size: Number of texts per forward pass.
        max_length: Truncation length in eval-tokenizer tokens.

    Returns:
        Dict with ``avg_nll`` (token-weighted mean NLL in nats),
        ``median_nll``, ``ppl`` (``exp(avg_nll)``), ``acc`` (top-1
        next-token accuracy) and ``tokens`` (number of scored tokens).

    Raises:
        RuntimeError: if no token could be scored (e.g. every text is empty
            or a single token).
    """
    total_nll = 0.0
    total_tokens = 0
    total_correct = 0
    all_nlls: list[float] = []

    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        enc = eval_tokenizer(
            batch,
            padding=True,
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
        ).to(device)
        input_ids = enc["input_ids"]
        attn_mask = enc["attention_mask"]

        # Scoring needs at least one (context, target) pair. An all-empty or
        # all-single-token batch has nothing to score, and a zero-width input
        # may not be accepted by the model, so skip its forward pass.
        if input_ids.size(1) >= 2:
            outputs = eval_model(
                input_ids=input_ids,
                attention_mask=attn_mask,
                use_cache=False,
                return_dict=True,
            )
            # fp32 for the loss even if the model runs in reduced precision.
            logits = outputs.logits[:, :-1].float()
            labels = input_ids[:, 1:]

            # A target counts only when it AND its context position are real
            # tokens. Under right padding this equals attn_mask[:, 1:]; under
            # left padding it also drops the first real token, which would
            # otherwise be "predicted" from a pad position.
            real = attn_mask.bool()
            valid = real[:, 1:] & real[:, :-1]

            nll = F.cross_entropy(
                logits.transpose(-1, -2),
                labels,
                reduction="none",
            )
            nll_valid = nll[valid]
            total_nll += nll_valid.sum().item()
            total_tokens += int(valid.sum().item())
            all_nlls.extend(nll_valid.detach().cpu().tolist())

            preds = logits.argmax(dim=-1)
            total_correct += int(((preds == labels) & valid).sum().item())

        print(f" scored {min(start + batch_size, len(texts))}/{len(texts)}")

    if total_tokens == 0:
        raise RuntimeError("No valid tokens scored - all samples were empty?")

    avg_nll = total_nll / total_tokens
    return {
        "avg_nll": avg_nll,
        "median_nll": float(np.median(all_nlls)),
        "ppl": float(np.exp(avg_nll)),
        "acc": total_correct / total_tokens,
        "tokens": total_tokens,
    }
|
|
|
|
def _load_ar_model(args: argparse.Namespace, device: torch.device, dtype: torch.dtype):
    """Load the AR checkpoint, its config, tokenizer and model.

    Returns ``(ckpt_path, tokenizer, model)``. The raw checkpoint dict stays
    local to this helper, so it is released once its weights are copied into
    ``model``.
    """
    ckpt_path = resolve_path(args.checkpoint)
    if ckpt_path is None or not ckpt_path.exists():
        raise FileNotFoundError(f"checkpoint not found: {args.checkpoint}")

    # Deserialize on CPU: load_state_dict copies tensors into the model's own
    # parameters on `device`, so mapping the whole checkpoint (which may hold
    # more than the model weights) onto the accelerator only duplicated memory.
    # NOTE(review): torch.load defaults to weights_only=True on PyTorch >= 2.6;
    # if the embedded config is not a plain container, pass weights_only=False
    # here (only for trusted checkpoints).
    ckpt = torch.load(ckpt_path, map_location="cpu")
    if args.config is not None:
        config = load_config(str(resolve_path(args.config)))
        config_source = f"cli:{args.config}"
    else:
        if "config" not in ckpt:
            raise ValueError(
                "--config was not provided and checkpoint has no embedded 'config' entry."
            )
        config = copy.deepcopy(ckpt["config"])
        config_source = f"checkpoint:{args.checkpoint}"
    print(f"Using config from {config_source}")

    tokenizer = build_tokenizer(config)
    if tokenizer.bos_token_id is None:
        raise RuntimeError("tokenizer has no bos_token_id")

    model = build_model(config, device).to(dtype)
    raw_state = ckpt.get("model", ckpt)
    load_result = _unwrap(model).load_state_dict(raw_state, strict=False)
    # strict=False is kept for compatibility, but mismatches are surfaced: a
    # key-prefix mismatch would otherwise leave the model at its random init
    # with no visible sign.
    for label, keys in (
        ("missing", getattr(load_result, "missing_keys", [])),
        ("unexpected", getattr(load_result, "unexpected_keys", [])),
    ):
        if keys:
            print(
                f"WARNING: {len(keys)} {label} keys when loading AR checkpoint "
                f"(first: {list(keys)[:5]})"
            )
    model.eval()
    print(f"Loaded AR checkpoint: {ckpt_path} (step={ckpt.get('step', '?')})")
    return ckpt_path, tokenizer, model


def _load_eval_lm(args: argparse.Namespace, device: torch.device):
    """Load the eval LM (fp32) and its tokenizer from local files only.

    Returns ``(eval_model, eval_tokenizer, eval_model_path, eval_tok_path)``.
    """
    # Deferred import: transformers is only needed for the scoring stage.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    eval_model_path = resolve_path(args.eval_model_path)
    eval_tok_path = resolve_path(args.eval_tokenizer_path) or eval_model_path
    print(f"Loading eval LM: {eval_model_path}")
    print(f"Loading eval tokenizer: {eval_tok_path}")

    eval_tokenizer = AutoTokenizer.from_pretrained(
        str(eval_tok_path),
        local_files_only=True,
    )
    if eval_tokenizer.pad_token is None:
        eval_tokenizer.pad_token = eval_tokenizer.eos_token
    # The scorer does not pass position_ids. For models that derive positions
    # from the token index (e.g. GPT-2 in many transformers versions), left
    # padding would shift the positions of shorter samples, so pin right padding.
    eval_tokenizer.padding_side = "right"

    eval_model = AutoModelForCausalLM.from_pretrained(
        str(eval_model_path),
        local_files_only=True,
        torch_dtype=torch.float32,
    ).to(device).eval()
    print(f"Eval LM loaded ({sum(p.numel() for p in eval_model.parameters()):,} params)")
    return eval_model, eval_tokenizer, eval_model_path, eval_tok_path


def _dump_json(path: Path, payload: dict) -> None:
    """Write ``payload`` to ``path`` as indented UTF-8 JSON, creating parent dirs."""
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)


def main():
    """Sample from an AR checkpoint, score the samples under an eval LM, save metrics.

    Steps: load the AR model, draw unconditional samples from a BOS prompt,
    decode them, free the AR model, load the eval LM, score the decoded
    texts, then write the metrics JSON (and optionally the samples).
    """
    args = parse_args()
    torch.manual_seed(args.seed)

    device = torch.device(args.device)
    dtype = resolve_dtype(args.dtype)

    ckpt_path, tokenizer, model = _load_ar_model(args, device, dtype)
    sampler = ARSampler(
        model=_unwrap(model),
        tokenizer=tokenizer,
        device=device,
        dtype=dtype,
    )

    total_seq_len = 1 + args.max_new_tokens
    print(
        f"Generating {args.num_samples} samples "
        f"(seq_len={total_seq_len}, temperature={args.temperature})..."
    )
    tokens = sample_many(
        sampler=sampler,
        bos_token_id=tokenizer.bos_token_id,
        num_samples=args.num_samples,
        batch_size=args.sample_batch_size,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        stop_on_eos=not args.no_stop_on_eos,
    )
    texts = tokenizer.batch_decode(tokens.tolist(), skip_special_tokens=True)
    if texts:
        print(f"First sample preview: {texts[0][:120]!r}")

    # Free the AR model before the eval LM is loaded on the same device.
    del sampler, model
    if device.type == "cuda":
        torch.cuda.empty_cache()

    eval_model, eval_tokenizer, eval_model_path, eval_tok_path = _load_eval_lm(args, device)

    print("Scoring samples under eval LM...")
    metrics = score_with_eval_lm(
        texts,
        eval_model,
        eval_tokenizer,
        device,
        args.eval_batch_size,
        args.eval_max_length,
    )
    metrics.update({
        "checkpoint": str(ckpt_path),
        "eval_model": str(eval_model_path),
        "eval_tokenizer": str(eval_tok_path),
        "num_samples": len(texts),
        "generated_seq_len": int(tokens.shape[1]),
        "mode": "ar_generation",
        "temperature": args.temperature,
        "top_k": args.top_k,
        "top_p": args.top_p,
        "prompt_len": 1,
        "max_new_tokens": args.max_new_tokens,
        "stop_on_eos": not args.no_stop_on_eos,
        "seed": args.seed,
    })
    print(json.dumps(metrics, indent=2))

    out_path = resolve_path(args.output)
    assert out_path is not None  # args.output is a non-empty string by default
    _dump_json(out_path, metrics)
    print(f"Saved metrics -> {out_path}")

    if args.save_samples:
        s_path = resolve_path(args.save_samples)
        assert s_path is not None  # guarded by the truthiness check above
        _dump_json(s_path, {"samples": texts})
        print(f"Saved samples -> {s_path}")
|
|
|
|
if __name__ == "__main__":
    # Script entry point; main() returns None, so the process exits with status 0.
    raise SystemExit(main())
|
|