transformer-weights / scripts /compute_perplexity.py
angerami's picture
chore: align download paths to /Flux/Projects/transformer-analysis/downloads
17a7cd2
Raw
History Blame Contribute Delete
12.6 kB
#!/usr/bin/env python3
"""
Compute perplexity for transformer models on WikiText-103 or The Pile test split.
Outputs a parquet side table (eval_metrics.parquet) joinable onto weight analysis
datasets on (model, revision).
Usage:
python compute_perplexity.py --model gpt2
python compute_perplexity.py --model pythia-70m-deduped --all-revisions --corpus pile --pile-tokens 51200
python compute_perplexity.py --all-models --corpus wikitext103
Output schema: model, revision, step, metric, value, source, corpus
"""
import argparse
import math
import os
import sys
from pathlib import Path
from typing import Optional
import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from huggingface_hub import snapshot_download
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
from transformer_analysis.model_registry import MODEL_CONFIGS, get_model_config
from transformer_analysis.device_utils import get_device
# ---------------------------------------------------------------------------
# Corpus loading
# ---------------------------------------------------------------------------
def load_pile_cache(pile_cache: str, tokenizer, pile_tokens: int) -> torch.Tensor:
    """Load a pre-materialized Pile corpus from a JSONL file and tokenize it.

    The file is read line by line (gzip-decompressed when the path ends in
    ``.gz``). Every non-blank line must be a JSON object with a ``"text"``
    field. Documents are tokenized without special tokens and concatenated in
    file order until at least ``pile_tokens`` tokens have been collected.

    Args:
        pile_cache: Path to a ``.jsonl`` or ``.jsonl.gz`` file.
        tokenizer: Callable tokenizer returning an object whose ``input_ids``
            is a 2-D tensor (one row per input text).
        pile_tokens: Maximum number of tokens to return.

    Returns:
        1-D tensor of at most ``pile_tokens`` token ids. It is shorter only
        when the file runs out of documents first.

    Raises:
        ValueError: If the file contains no documents.
    """
    import gzip
    import json as _json

    tokens_collected = []
    n = 0
    opener = gzip.open if pile_cache.endswith(".gz") else open
    with opener(pile_cache, "rt", encoding="utf-8") as f:
        for line in f:
            # Tolerate blank lines (e.g. a trailing empty line) instead of
            # crashing in json.loads("").
            if not line.strip():
                continue
            text = _json.loads(line)["text"]
            enc = tokenizer(text, return_tensors="pt",
                            truncation=False, add_special_tokens=False)
            tokens_collected.append(enc.input_ids[0])
            n += len(tokens_collected[-1])
            if n >= pile_tokens:
                break
    if not tokens_collected:
        raise ValueError(f"pile_cache file appears empty: {pile_cache}")
    if n < pile_tokens:
        # Not fatal, but the evaluation will silently use fewer tokens.
        print(f" Warning: {pile_cache} holds only {n:,} tokens "
              f"(requested {pile_tokens:,}).")
    return torch.cat(tokens_collected)[:pile_tokens]
def load_corpus_tokens(corpus: str, tokenizer, pile_tokens: int = 204800,
                       pile_seed: int = 42, pile_cache: Optional[str] = None) -> torch.Tensor:
    """Tokenize an evaluation corpus into one flat 1-D tensor of token ids.

    Args:
        corpus: ``"wikitext103"`` (the whole WikiText-103 raw test split) or
            ``"pile"`` (the first ``pile_tokens`` tokens of the Pile test split).
        tokenizer: Callable tokenizer. Special tokens are not added.
        pile_tokens: Token budget for the Pile corpus (unused for WikiText).
        pile_seed: Shuffle seed, used only when streaming the Pile.
        pile_cache: Optional local JSONL(.gz) cache. When set, the Pile is
            read from it via ``load_pile_cache`` instead of being streamed.

    Returns:
        1-D tensor of token ids.

    Raises:
        ValueError: For an unknown corpus name, or if the Pile stream yields
            no documents.
    """
    if corpus == "wikitext103":
        ds = load_dataset("wikitext", "wikitext-103-raw-v1", split="test")
        # Drop blank lines, then rejoin the remaining lines with blank-line separators.
        text = "\n\n".join(t for t in ds["text"] if t.strip())
        encodings = tokenizer(text, return_tensors="pt", truncation=False,
                              add_special_tokens=False)
        return encodings.input_ids[0]
    elif corpus == "pile":
        if pile_cache:
            print(f" Loading Pile corpus from cache: {pile_cache}")
            return load_pile_cache(pile_cache, tokenizer, pile_tokens)
        # Fall back to streaming if no cache provided
        print(" No --pile-cache set; streaming from HuggingFace (slow for repeated runs).")
        print(" Run prepare_eval_corpus.py once to create a local cache.")
        ds = load_dataset("EleutherAI/pile", split="test", streaming=True)
        tokens_collected = []
        # Running total: avoids re-summing every collected chunk per document.
        n_collected = 0
        for example in ds.shuffle(seed=pile_seed, buffer_size=1000):
            enc = tokenizer(example["text"], return_tensors="pt",
                            truncation=False, add_special_tokens=False)
            tokens_collected.append(enc.input_ids[0])
            n_collected += len(tokens_collected[-1])
            if n_collected >= pile_tokens:
                break
        if not tokens_collected:
            # torch.cat([]) would otherwise fail with an unhelpful message.
            raise ValueError("Pile test stream yielded no documents")
        return torch.cat(tokens_collected)[:pile_tokens]
    else:
        raise ValueError(f"Unknown corpus: {corpus!r}. Choose 'wikitext103' or 'pile'.")
# ---------------------------------------------------------------------------
# Forward pass and collector pattern
# ---------------------------------------------------------------------------
def run_inference(model, input_ids: torch.Tensor, attention_mask: torch.Tensor,
                  output_hidden_states: bool = False,
                  output_attentions: bool = False,
                  labels: Optional[torch.Tensor] = None):
    """
    Run a single gradient-free forward pass and return the model output.

    ``labels`` defaults to ``input_ids``, so the loss covers every position.
    Pass a copy with ignored positions set to -100 to restrict the loss to a
    subset of tokens, as ``eval_loop`` does for the sliding-window context.

    output_hidden_states / output_attentions are the hooks for future
    data-weighted statistics (e.g. <W_QK>_{data}):
      - output_hidden_states=True → out.hidden_states[layer] for dressed operators
      - output_attentions=True    → out.attentions[layer] for <A>_{data}
    """
    if labels is None:
        labels = input_ids
    with torch.no_grad():
        return model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )


class NLLCollector:
    """Accumulates token-weighted negative log-likelihood for perplexity."""

    name = "nll"

    def __init__(self):
        self._total_nll = 0.0  # sum of per-token NLL over all scored tokens
        self._n_tokens = 0     # number of scored tokens contributing to the sum

    def update(self, out, n_tokens: int):
        """Add one window: ``out.loss`` is the mean NLL over ``n_tokens`` tokens.

        Windows with no scored tokens are skipped; their mean loss is undefined
        (NaN) and would otherwise poison the running total.
        """
        if n_tokens <= 0:
            return
        self._total_nll += out.loss.item() * n_tokens
        self._n_tokens += n_tokens

    def result(self) -> float:
        """Mean NLL per scored token, or NaN if nothing was scored."""
        if self._n_tokens == 0:
            return float("nan")
        return self._total_nll / self._n_tokens


def eval_loop(model, input_ids: torch.Tensor, device, stride: int = 512,
              max_tokens: Optional[int] = None,
              collectors=None) -> dict:
    """
    Sliding-window perplexity loop (Hugging Face "perplexity of fixed-length
    models" strategy).

    A window of up to ``model.config.max_position_embeddings`` tokens starts
    every ``stride`` tokens. Tokens already scored by the previous window are
    kept as context only (their labels are set to -100), so each token is
    scored at most once, with as much left context as the window allows.
    When ``stride`` equals the context length the windows do not overlap and
    the first token of each window goes unscored, because it has no context.

    Args:
        model: Causal LM called via ``run_inference``.
        input_ids: 1-D tensor of token ids.
        device: Device the token ids are moved to.
        stride: Step between window starts; must satisfy
            0 < stride <= context length.
        max_tokens: Optional cap on how many leading tokens are evaluated
            (falsy means all).
        collectors: Objects with ``name``, ``update(out, n_tokens)`` and
            ``result()``; defaults to a single ``NLLCollector``. ``n_tokens`` is
            the number of positions that contributed to ``out.loss``. Extend by
            adding new collector types (e.g. DressedWQKCollector for
            <W_QK>_{data}).

    Returns:
        ``{collector.name: collector.result()}``.

    Raises:
        ValueError: If ``stride`` is outside (0, context length].
    """
    if collectors is None:
        collectors = [NLLCollector()]
    max_length = model.config.max_position_embeddings
    if not 0 < stride <= max_length:
        # stride > context would silently skip tokens between windows.
        raise ValueError(
            f"stride must be in (0, {max_length}] for this model, got {stride}")
    seq_len = min(len(input_ids), max_tokens or len(input_ids))
    input_ids = input_ids[:seq_len].unsqueeze(0).to(device)
    prev_end = 0
    for begin in range(0, seq_len, stride):
        end = min(begin + max_length, seq_len)
        window = input_ids[:, begin:end]
        labels = window.clone()
        # Positions [begin, prev_end) were scored by the previous window.
        # Keep them as context only. Non-negative because stride <= max_length.
        n_context = prev_end - begin
        labels[:, :n_context] = -100
        # Causal LMs shift labels by one internally, so label 0 is never
        # scored. Count exactly the positions that enter the mean loss.
        n_scored = int((labels[:, 1:] != -100).sum())
        mask = torch.ones_like(window)
        out = run_inference(model, window, mask, labels=labels)
        for collector in collectors:
            collector.update(out, n_scored)
        prev_end = end
        if end == seq_len:
            break
    return {c.name: c.result() for c in collectors}
# ---------------------------------------------------------------------------
# Per-model evaluation
# ---------------------------------------------------------------------------
def evaluate_model(model_name: str, revision: Optional[str],
                   corpus: str, pile_tokens: int, cache_dir: str,
                   device_str: Optional[str], stride: int = 512,
                   max_tokens: Optional[int] = None,
                   pile_cache: Optional[str] = None) -> dict:
    """Download one model checkpoint, evaluate it on ``corpus`` and return metrics.

    Args:
        model_name: Registry key passed to ``get_model_config``.
        revision: Hub revision to download; None means the default branch.
        corpus: ``"wikitext103"`` or ``"pile"`` (see ``load_corpus_tokens``).
        pile_tokens: Token budget for the Pile corpus.
        cache_dir: Root directory; files go to ``<cache_dir>/<model>/<revision>``.
        device_str: Explicit device name, or None to use ``get_device()``.
        stride: Sliding-window stride forwarded to ``eval_loop``.
        max_tokens: Optional cap on evaluated tokens, forwarded to ``eval_loop``.
        pile_cache: Optional local Pile cache path.

    Returns:
        Dict with keys model, revision ("main" when ``revision`` is None),
        step (int parsed from a ``step<N>`` revision, else None), perplexity,
        nll (mean per-token NLL in nats), bpb, corpus and n_tokens (number of
        corpus tokens handed to the evaluation loop after ``max_tokens`` capping).
    """
    model_config = get_model_config(model_name)
    revision_str = revision or "main"
    print(f" Downloading {model_name} @ {revision_str} ...")
    cache_path = snapshot_download(
        repo_id=model_config.repo_id,
        revision=revision,
        cache_dir=f"{cache_dir}/{model_name}/{revision_str}",
        allow_patterns=["*.safetensors", "*.bin", "*.json", "tokenizer*"],
        resume_download=True,
    )

    device = torch.device(device_str or get_device())
    print(f" Loading model on {device} ...")
    tokenizer = AutoTokenizer.from_pretrained(cache_path)
    model = AutoModelForCausalLM.from_pretrained(cache_path, torch_dtype=torch.float32)
    model = model.to(device).eval()

    print(f" Loading corpus ({corpus}) ...")
    tokens = load_corpus_tokens(corpus, tokenizer, pile_tokens=pile_tokens,
                                pile_cache=pile_cache)
    # eval_loop only consumes the first max_tokens tokens; report that count
    # rather than the full corpus length.
    n_eval = min(len(tokens), max_tokens) if max_tokens else len(tokens)
    print(f" Evaluating on {n_eval:,} tokens ...")
    results = eval_loop(model, tokens, device, stride=stride, max_tokens=max_tokens)

    nll = results["nll"]
    try:
        ppl = math.exp(nll)
    except OverflowError:
        # Degenerate checkpoints can have an NLL too large for exp().
        ppl = float("inf")
    # NOTE(review): this converts nats/token to bits/token. It is not bits per
    # UTF-8 byte despite the "bpb" name. Kept as-is for schema compatibility
    # with previously written rows.
    bpb = nll / math.log(2)

    # Free the weights before the next revision is loaded.
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    step = None
    if revision and revision.startswith("step"):
        try:
            step = int(revision[len("step"):])
        except ValueError:
            pass  # a revision name that merely starts with "step"
    return {
        "model": model_name,
        "revision": revision_str,
        "step": step,
        "perplexity": ppl,
        "nll": nll,
        "bpb": bpb,
        "corpus": corpus,
        "n_tokens": n_eval,
    }
# ---------------------------------------------------------------------------
# Output helpers
# ---------------------------------------------------------------------------
def to_long_format(row: dict) -> list[dict]:
    """Explode one evaluation result into long-format rows, one per metric.

    Every output row carries model, revision, step, source ("eval_pass") and
    corpus, plus ``metric`` (the metric name) and ``value``. Metrics are
    emitted in the order perplexity, nll, bpb; other keys of ``row`` are ignored.
    """
    shared = {
        "model": row["model"],
        "revision": row["revision"],
        "step": row["step"],
        "source": "eval_pass",
        "corpus": row["corpus"],
    }
    return [
        dict(shared, metric=metric_name, value=row[metric_name])
        for metric_name in ("perplexity", "nll", "bpb")
    ]
def append_to_parquet(rows: list[dict], out_path: str):
    """Upsert long-format metric rows into the parquet table at ``out_path``.

    Existing rows that share a (model, revision, corpus) key with any new row
    are dropped before the new rows are appended. Re-running an evaluation
    therefore replaces its earlier results instead of duplicating them. The
    table is written to a temporary sibling file and then moved into place with
    ``os.replace``, so an interrupted write leaves the previous file intact.
    An empty ``rows`` list is a no-op.
    """
    if not rows:
        return
    new_df = pd.DataFrame(rows)
    key = ["model", "revision", "corpus"]
    if os.path.exists(out_path):
        existing = pd.read_parquet(out_path)
        replaced = set(zip(*(new_df[c] for c in key)))
        # Boolean array (not list) so an empty table still indexes rows, not columns.
        keep = np.array(
            [k not in replaced for k in zip(*(existing[c] for c in key))],
            dtype=bool,
        )
        combined = pd.concat([existing[keep], new_df], ignore_index=True)
    else:
        combined = new_df
    out_dir = os.path.dirname(out_path)
    if out_dir:  # a bare filename has no directory to create; makedirs("") raises
        os.makedirs(out_dir, exist_ok=True)
    tmp_path = f"{out_path}.tmp"
    combined.to_parquet(tmp_path, index=False)
    os.replace(tmp_path, out_path)
    print(f" Saved {len(new_df)} rows → {out_path}")
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def main():
    """Command-line entry point.

    Evaluates the requested models/revisions and upserts their metrics into
    the parquet side table. Each revision's rows are saved as soon as that
    revision finishes, so an interrupted sweep keeps completed work. A failing
    revision is reported with its traceback and the sweep continues.
    """
    import traceback

    parser = argparse.ArgumentParser(description="Compute perplexity for transformer models")
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--model", type=str)
    group.add_argument("--all-models", action="store_true")
    parser.add_argument("--revision", type=str, default=None)
    parser.add_argument("--all-revisions", action="store_true")
    parser.add_argument("--corpus", type=str, default="wikitext103",
                        choices=["wikitext103", "pile"])
    parser.add_argument("--pile-tokens", type=int, default=204800,
                        help="Token count to evaluate from Pile corpus (default: 200K)")
    parser.add_argument("--pile-cache", type=str, default=None,
                        help="Path to pre-materialized Pile corpus (.jsonl.gz) from prepare_eval_corpus.py")
    parser.add_argument("--max-tokens", type=int, default=None,
                        help="Cap total tokens evaluated (default: all)")
    parser.add_argument("--stride", type=int, default=512)
    parser.add_argument("--out", type=str, default="outputs/eval_metrics/eval_metrics.parquet")
    parser.add_argument("--cache", type=str,
                        default="/Flux/Projects/transformer-analysis/downloads")
    parser.add_argument("--device", type=str, default=None, choices=["cuda", "mps", "cpu"])
    args = parser.parse_args()

    models = list(MODEL_CONFIGS.keys()) if args.all_models else [args.model]
    for model_name in models:
        try:
            model_config = get_model_config(model_name)
        except ValueError as e:
            print(f"Skipping {model_name}: {e}")
            continue

        # --all-revisions wins over --revision; None means the default branch.
        if args.all_revisions:
            revisions = model_config.revisions or [None]
        elif args.revision:
            revisions = [args.revision]
        else:
            revisions = [None]

        print(f"\n{'='*60}\n{model_name} — {len(revisions)} revision(s)\n{'='*60}")
        for rev in tqdm(revisions, desc=model_name):
            try:
                result = evaluate_model(
                    model_name=model_name, revision=rev,
                    corpus=args.corpus, pile_tokens=args.pile_tokens,
                    cache_dir=args.cache, device_str=args.device,
                    stride=args.stride, max_tokens=args.max_tokens,
                    pile_cache=args.pile_cache,
                )
                print(f" ppl={result['perplexity']:.2f} bpb={result['bpb']:.4f}")
                # Persist immediately: long --all-revisions sweeps should not
                # lose every finished checkpoint on a crash or Ctrl-C.
                append_to_parquet(to_long_format(result), args.out)
            except Exception as e:
                # Best-effort sweep: report the failure (with traceback) and
                # move on to the next revision.
                print(f" ERROR {model_name} @ {rev}: {e}")
                traceback.print_exc()


if __name__ == "__main__":
    main()