sad / scripts /eval_ar_gen_ppl.py
haochengsama's picture
Add files using upload-large-folder tool
8b0aeb2 verified
Raw
History Blame Contribute Delete
10.4 kB
#!/usr/bin/env python3
"""
eval_ar_gen_ppl.py - Generative perplexity of AR baseline samples under an eval LM.
Mirrors eval_gen_ppl.py for the autoregressive baseline:
1. Draw N unconditional samples from a trained AR checkpoint.
2. Decode them into text with the GPT-2 tokenizer used for training.
3. Score them under a pretrained eval LM (default: local gpt2-large).
4. Report avg_nll / ppl / acc and optionally save the samples.
"""
from __future__ import annotations
import argparse
import copy
import json
import sys
from pathlib import Path
# Repository root: the parent of this script's directory. resolve_path()
# anchors relative CLI paths (checkpoints, eval model, outputs) here.
ROOT = Path(__file__).resolve().parents[1] # sad/
import numpy as np
import torch
import torch.nn.functional as F
# Put the repo root first on sys.path so project packages import when this
# file is run directly as a script.
sys.path.insert(0, str(ROOT))
sys.path.insert(0, str(Path(__file__).parent)) # for `inference_ar`
# AR model/tokenizer/sampler helpers, resolved via the script directory that
# was just added to sys.path.
from inference_ar import ARSampler, build_model, build_tokenizer, load_config, resolve_dtype, _unwrap
def resolve_path(raw: str | None) -> Path | None:
    """Turn a CLI path string into a ``Path``.

    ``None`` passes straight through. Absolute paths are returned as-is;
    relative paths are interpreted relative to the repository ``ROOT``.
    """
    if raw is None:
        return None
    candidate = Path(raw)
    return candidate if candidate.is_absolute() else ROOT / candidate
def parse_args():
    """Define and parse the command line for AR generative-perplexity eval.

    Options cover the AR checkpoint/config, sampling controls, the eval LM
    and its tokenizer, device/dtype, seeding, and output locations.

    Returns:
        argparse.Namespace parsed from ``sys.argv``.
    """
    parser = argparse.ArgumentParser()
    device_default = "cuda" if torch.cuda.is_available() else "cpu"

    # (flag, add_argument keyword options), registered in this exact order.
    specs = (
        ("--checkpoint", dict(type=str, required=True,
                              help="AR checkpoint.")),
        ("--config", dict(type=str, default=None,
                          help="Optional config path. If omitted, uses the config "
                               "stored inside --checkpoint.")),
        ("--num_samples", dict(type=int, default=256,
                               help="Total unconditional samples to generate.")),
        ("--sample_batch_size", dict(type=int, default=16,
                                     help="Batch size for AR sampling.")),
        ("--eval_batch_size", dict(type=int, default=8,
                                   help="Batch size when feeding samples to the eval LM.")),
        ("--eval_model_path", dict(type=str, default="models/gpt2-large")),
        ("--eval_tokenizer_path", dict(type=str, default=None,
                                       help="Defaults to --eval_model_path when omitted.")),
        ("--eval_max_length", dict(type=int, default=1024,
                                   help="Truncation length for eval-LM tokenization.")),
        ("--max_new_tokens", dict(type=int, default=511,
                                  help="Number of new tokens sampled after the BOS prompt.")),
        ("--temperature", dict(type=float, default=1.0,
                               help="Sampling temperature. 0 means greedy decoding.")),
        ("--top_k", dict(type=int, default=0,
                         help="0 disables top-k sampling.")),
        ("--top_p", dict(type=float, default=1.0,
                         help="1.0 disables top-p sampling.")),
        ("--no-stop-on-eos", dict(action="store_true",
                                  help="Keep sampling until max_new_tokens is reached.")),
        ("--device", dict(type=str, default=device_default)),
        ("--dtype", dict(type=str, default="bf16",
                         choices=["bf16", "fp16", "fp32"],
                         help="dtype for AR sampling (eval LM always runs fp32).")),
        ("--seed", dict(type=int, default=42)),
        ("--output", dict(type=str, default="outputs/ar_gen_ppl_metrics.json")),
        ("--save_samples", dict(type=str, default=None,
                                help="Optional path to dump decoded text samples (JSON).")),
    )
    for flag, options in specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()
@torch.no_grad()
def sample_many(
    sampler: ARSampler,
    bos_token_id: int,
    num_samples: int,
    batch_size: int,
    max_new_tokens: int,
    temperature: float,
    top_k: int,
    top_p: float,
    stop_on_eos: bool,
) -> torch.Tensor:
    """Draw ``num_samples`` unconditional sequences from ``sampler`` in batches.

    Every sequence starts from a one-token prompt of ``bos_token_id`` and is
    extended by ``sampler.generate`` using the sampler tokenizer's EOS id.

    Args:
        sampler: AR sampler exposing ``device``, ``tokenizer`` and ``generate``.
        bos_token_id: token id used as the single prompt token.
        num_samples: total number of sequences to draw (must be > 0).
        batch_size: sequences per ``generate`` call (must be > 0).
        max_new_tokens, temperature, top_k, top_p, stop_on_eos: forwarded
            unchanged to ``sampler.generate``.

    Returns:
        Tensor with ``num_samples`` rows: all batches concatenated along dim 0.
        If batches come back with different lengths, shorter ones are
        right-padded with the EOS id up to the longest length.

    Raises:
        ValueError: if ``num_samples`` or ``batch_size`` is not positive, or if
            batch lengths differ and the tokenizer has no ``eos_token_id`` to
            pad with.
    """
    if num_samples <= 0:
        raise ValueError(f"num_samples must be positive, got {num_samples}")
    if batch_size <= 0:
        # With bs == 0 the loop below would never make progress.
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    eos_token_id = sampler.tokenizer.eos_token_id
    chunks = []
    done = 0
    while done < num_samples:
        bs = min(batch_size, num_samples - done)
        prompt_ids = torch.full(
            (bs, 1),
            bos_token_id,
            dtype=torch.long,
            device=sampler.device,
        )
        out = sampler.generate(
            prompt_ids=prompt_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            eos_token_id=eos_token_id,
            stop_on_eos=stop_on_eos,
        )
        chunks.append(out)
        done += bs
        print(f" sampled {done}/{num_samples}")

    # NOTE(review): if generate() ends a batch early once every row has hit
    # EOS, batches can differ in length and a plain cat(dim=0) would fail.
    # Pad with EOS; for tokenizers that register EOS as a special token
    # (e.g. GPT-2) it is dropped by batch_decode(skip_special_tokens=True).
    max_len = max(chunk.shape[1] for chunk in chunks)
    if any(chunk.shape[1] != max_len for chunk in chunks):
        if eos_token_id is None:
            raise ValueError(
                "sampled batches have different lengths and the tokenizer "
                "has no eos_token_id to pad them with"
            )
        chunks = [
            torch.cat(
                [chunk, chunk.new_full((chunk.shape[0], max_len - chunk.shape[1]), eos_token_id)],
                dim=1,
            )
            for chunk in chunks
        ]
    return torch.cat(chunks, dim=0)
@torch.no_grad()
def score_with_eval_lm(
    texts: list[str],
    eval_model,
    eval_tokenizer,
    device: torch.device,
    batch_size: int,
    max_length: int,
) -> dict:
    """Score ``texts`` under a causal eval LM and aggregate token-level metrics.

    Texts are tokenized in batches (padded, truncated to ``max_length``). The
    logits at position t are scored against token t+1, so the first token of
    each text only conditions the model and is never scored.

    Args:
        texts: decoded samples to score.
        eval_model: HF-style causal LM returning an object with ``.logits``.
        eval_tokenizer: tokenizer callable returning ``input_ids`` and
            ``attention_mask``; it must have a pad token set.
        device: device to move each tokenized batch to.
        batch_size: texts per forward pass (must be > 0).
        max_length: truncation length for eval-LM tokenization.

    Returns:
        dict with ``avg_nll`` (mean NLL over scored tokens), ``median_nll``,
        ``ppl`` (= exp(avg_nll)), ``acc`` (argmax next-token accuracy over
        scored tokens), and ``tokens`` (number of scored tokens).

    Raises:
        ValueError: if ``batch_size`` is not positive.
        RuntimeError: if no token could be scored.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    # An empty string has no tokens, so it can never contribute a scored
    # token; a batch made only of empty strings would also hand the eval LM
    # a zero-length input. Dropping them leaves every metric unchanged.
    scorable = [t for t in texts if t]
    if len(scorable) < len(texts):
        print(f" skipping {len(texts) - len(scorable)} empty sample(s)")

    total_nll = 0.0
    total_tokens = 0
    total_correct = 0.0
    all_nlls: list[float] = []
    for i in range(0, len(scorable), batch_size):
        batch = scorable[i:i + batch_size]
        enc = eval_tokenizer(
            batch,
            padding=True,
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
        ).to(device)
        input_ids = enc["input_ids"]
        attn_mask = enc["attention_mask"]
        outputs = eval_model(
            input_ids=input_ids,
            attention_mask=attn_mask,
            use_cache=False,
            return_dict=True,
        )
        logits = outputs.logits[:, :-1]
        labels = input_ids[:, 1:]
        # Score a label only when both it AND the position predicting it are
        # real tokens. For right padding this equals attn_mask[:, 1:]; for
        # left padding it also excludes the first real token, whose context
        # would otherwise be a pad token.
        valid = (attn_mask[:, 1:] * attn_mask[:, :-1]).bool()
        nll = F.cross_entropy(
            logits.transpose(-1, -2),
            labels,
            reduction="none",
        )
        nll_valid = nll[valid]
        total_nll += nll_valid.sum().item()
        total_tokens += int(valid.sum().item())
        all_nlls.extend(nll_valid.detach().cpu().tolist())
        preds = logits.argmax(dim=-1)
        total_correct += (preds == labels)[valid].sum().item()
        print(f" scored {min(i + batch_size, len(scorable))}/{len(scorable)}")

    if total_tokens == 0:
        raise RuntimeError("No valid tokens scored - all samples were empty?")
    avg_nll = total_nll / total_tokens
    return {
        "avg_nll": avg_nll,
        "median_nll": float(np.median(all_nlls)),
        "ppl": float(np.exp(avg_nll)),
        "acc": total_correct / total_tokens,
        "tokens": total_tokens,
    }
def _write_json(path: Path, payload: dict) -> None:
    """Write ``payload`` as indented JSON to ``path``, creating parent dirs."""
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)


def main():
    """Sample from an AR checkpoint, score the samples under an eval LM, save metrics.

    Steps:
      1. Load the AR checkpoint and its config (from --config or embedded).
      2. Draw --num_samples unconditional samples from a BOS prompt.
      3. Decode them, free the AR model, and load the eval LM in fp32.
      4. Score the texts and write metrics JSON to --output, plus the decoded
         samples to --save_samples when given.

    Raises:
        FileNotFoundError: if --checkpoint does not exist.
        ValueError: if --config is omitted and the checkpoint has no
            embedded ``config`` entry.
        RuntimeError: if the training tokenizer has no ``bos_token_id``.
    """
    args = parse_args()
    torch.manual_seed(args.seed)
    device = torch.device(args.device)
    dtype = resolve_dtype(args.dtype)

    ckpt_path = resolve_path(args.checkpoint)
    if ckpt_path is None or not ckpt_path.exists():
        raise FileNotFoundError(f"checkpoint not found: {args.checkpoint}")
    # Load to CPU: load_state_dict copies into the model's own device anyway,
    # and a device-resident copy held by `ckpt` would otherwise stay allocated
    # on the accelerator while the eval LM runs.
    # NOTE(review): torch.load unpickles arbitrary objects - trusted files only.
    ckpt = torch.load(ckpt_path, map_location="cpu")

    if args.config is not None:
        config = load_config(str(resolve_path(args.config)))
        config_source = f"cli:{args.config}"
    else:
        if "config" not in ckpt:
            raise ValueError(
                "--config was not provided and checkpoint has no embedded 'config' entry."
            )
        config = copy.deepcopy(ckpt["config"])
        config_source = f"checkpoint:{args.checkpoint}"
    print(f"Using config from {config_source}")

    tokenizer = build_tokenizer(config)
    bos_token_id = tokenizer.bos_token_id
    if bos_token_id is None:
        raise RuntimeError("tokenizer has no bos_token_id")

    model = build_model(config, device).to(dtype)
    raw_state = ckpt.get("model", ckpt)
    # strict=False tolerates partial mismatches, but a wholesale mismatch
    # (e.g. every key carrying a wrapper prefix) would silently leave the
    # model at its initial weights - surface it instead of hiding it.
    load_result = _unwrap(model).load_state_dict(raw_state, strict=False)
    missing = list(getattr(load_result, "missing_keys", None) or [])
    unexpected = list(getattr(load_result, "unexpected_keys", None) or [])
    if missing or unexpected:
        print(
            f"WARNING: checkpoint/model key mismatch "
            f"({len(missing)} missing, {len(unexpected)} unexpected); "
            f"first missing={missing[:5]}, first unexpected={unexpected[:5]}"
        )
    model.eval()
    print(f"Loaded AR checkpoint: {ckpt_path} (step={ckpt.get('step', '?')})")
    # The weights now live in `model`; drop the checkpoint copy.
    del ckpt, raw_state

    sampler = ARSampler(
        model=_unwrap(model),
        tokenizer=tokenizer,
        device=device,
        dtype=dtype,
    )
    total_seq_len = 1 + args.max_new_tokens
    print(
        f"Generating {args.num_samples} samples "
        f"(seq_len={total_seq_len}, temperature={args.temperature})..."
    )
    tokens = sample_many(
        sampler=sampler,
        bos_token_id=bos_token_id,
        num_samples=args.num_samples,
        batch_size=args.sample_batch_size,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        stop_on_eos=not args.no_stop_on_eos,
    )
    texts = tokenizer.batch_decode(tokens.tolist(), skip_special_tokens=True)
    print(f"First sample preview: {texts[0][:120]!r}")

    # Free the AR model before loading the (larger) eval LM.
    del sampler, model
    if device.type == "cuda":
        torch.cuda.empty_cache()

    from transformers import AutoModelForCausalLM, AutoTokenizer
    eval_model_path = resolve_path(args.eval_model_path)
    eval_tok_path = resolve_path(args.eval_tokenizer_path) or eval_model_path
    print(f"Loading eval LM: {eval_model_path}")
    print(f"Loading eval tokenizer: {eval_tok_path}")
    eval_tokenizer = AutoTokenizer.from_pretrained(
        str(eval_tok_path),
        local_files_only=True,
    )
    # Batched tokenization with padding=True needs a pad token.
    if eval_tokenizer.pad_token is None:
        eval_tokenizer.pad_token = eval_tokenizer.eos_token
    eval_model = AutoModelForCausalLM.from_pretrained(
        str(eval_model_path),
        local_files_only=True,
        torch_dtype=torch.float32,
    ).to(device).eval()
    print(f"Eval LM loaded ({sum(p.numel() for p in eval_model.parameters()):,} params)")

    print("Scoring samples under eval LM...")
    metrics = score_with_eval_lm(
        texts,
        eval_model,
        eval_tokenizer,
        device,
        args.eval_batch_size,
        args.eval_max_length,
    )
    metrics.update({
        "checkpoint": str(ckpt_path),
        "eval_model": str(eval_model_path),
        "eval_tokenizer": str(eval_tok_path),
        "num_samples": len(texts),
        "generated_seq_len": int(tokens.shape[1]),
        "mode": "ar_generation",
        "temperature": args.temperature,
        "top_k": args.top_k,
        "top_p": args.top_p,
        "prompt_len": 1,
        "max_new_tokens": args.max_new_tokens,
        "stop_on_eos": not args.no_stop_on_eos,
    })
    print(json.dumps(metrics, indent=2))

    out_path = resolve_path(args.output)
    assert out_path is not None  # args.output is always a str
    _write_json(out_path, metrics)
    print(f"Saved metrics -> {out_path}")

    if args.save_samples:
        s_path = resolve_path(args.save_samples)
        assert s_path is not None
        _write_json(s_path, {"samples": texts})
        print(f"Saved samples -> {s_path}")


if __name__ == "__main__":
    main()