#!/usr/bin/env python3
"""
Compute perplexity for transformer models on WikiText-103 or The Pile test split.

Outputs a parquet side table (eval_metrics.parquet) joinable onto weight analysis
datasets on (model, revision).

Usage:
    python compute_perplexity.py --model gpt2
    python compute_perplexity.py --model pythia-70m-deduped --all-revisions --corpus pile --pile-tokens 51200
    python compute_perplexity.py --all-models --corpus wikitext103

Output schema:
    model, revision, step, metric, value, source, corpus

Metrics:
    perplexity  exp(nll)
    nll         mean negative log-likelihood per scored token, in nats
    bpb         bits per UTF-8 byte of the evaluated text
"""
import argparse
import gzip
import json
import math
import os
import sys
from collections.abc import Iterable
from pathlib import Path
from typing import Optional

import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from huggingface_hub import snapshot_download
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
from transformer_analysis.model_registry import MODEL_CONFIGS, get_model_config  # noqa: E402
from transformer_analysis.device_utils import get_device  # noqa: E402


# ---------------------------------------------------------------------------
# Corpus loading
# ---------------------------------------------------------------------------

def _tokenize_until(texts: Iterable[str], tokenizer, n_tokens: int) -> list[torch.Tensor]:
    """Tokenize documents from *texts* in order until >= *n_tokens* tokens are collected.

    Returns one 1-D id tensor per consumed document (an empty list if *texts*
    yields nothing). Special tokens are not added. A running count is kept so the
    cost stays linear in the number of documents.

    NOTE(review): callers concatenate the documents with no separator token, so
    the model is scored across unrelated-document boundaries; consider inserting
    ``tokenizer.eos_token_id`` between documents if that matters for the analysis.
    """
    chunks: list[torch.Tensor] = []
    n_collected = 0
    for text in texts:
        ids = tokenizer(text, return_tensors="pt", truncation=False,
                        add_special_tokens=False).input_ids[0]
        chunks.append(ids)
        n_collected += len(ids)
        if n_collected >= n_tokens:
            break
    return chunks


def load_pile_cache(pile_cache: str, tokenizer, pile_tokens: int) -> torch.Tensor:
    """Load a pre-materialized Pile corpus from a (optionally gzipped) JSONL file.

    Every non-blank line must be a JSON object with a ``"text"`` field. Documents
    are tokenized in file order until ``pile_tokens`` tokens are collected, then
    concatenated and truncated to at most ``pile_tokens`` tokens.

    Raises:
        ValueError: if the file contains no documents.
    """
    opener = gzip.open if pile_cache.endswith(".gz") else open
    with opener(pile_cache, "rt", encoding="utf-8") as f:
        # Skip blank lines (e.g. a trailing newline) instead of failing in json.loads.
        texts = (json.loads(line)["text"] for line in f if line.strip())
        chunks = _tokenize_until(texts, tokenizer, pile_tokens)
    if not chunks:
        raise ValueError(f"pile_cache file appears empty: {pile_cache}")
    return torch.cat(chunks)[:pile_tokens]


def load_corpus_tokens(corpus: str, tokenizer, pile_tokens: int = 204800,
                       pile_seed: int = 42, pile_cache: Optional[str] = None) -> torch.Tensor:
    """Return the evaluation corpus as a single 1-D tensor of token ids.

    corpus:
        "wikitext103" — the full WikiText-103 raw test split (blank lines dropped,
        remaining lines joined with blank-line separators).
        "pile" — the first ``pile_tokens`` tokens of the Pile test split, read from
        ``pile_cache`` if given, otherwise streamed from the Hub and shuffled with
        ``pile_seed``.

    Raises:
        ValueError: for an unknown corpus name or an empty Pile source.
    """
    if corpus == "wikitext103":
        ds = load_dataset("wikitext", "wikitext-103-raw-v1", split="test")
        text = "\n\n".join(t for t in ds["text"] if t.strip())
        encodings = tokenizer(text, return_tensors="pt", truncation=False, add_special_tokens=False)
        return encodings.input_ids[0]

    if corpus == "pile":
        if pile_cache:
            print(f" Loading Pile corpus from cache: {pile_cache}")
            return load_pile_cache(pile_cache, tokenizer, pile_tokens)

        # Fall back to streaming if no cache provided
        print(" No --pile-cache set; streaming from HuggingFace (slow for repeated runs).")
        print(" Run prepare_eval_corpus.py once to create a local cache.")
        ds = load_dataset("EleutherAI/pile", split="test", streaming=True)
        texts = (example["text"] for example in ds.shuffle(seed=pile_seed, buffer_size=1000))
        chunks = _tokenize_until(texts, tokenizer, pile_tokens)
        if not chunks:
            raise ValueError("Pile test stream yielded no documents")
        return torch.cat(chunks)[:pile_tokens]

    raise ValueError(f"Unknown corpus: {corpus!r}. Choose 'wikitext103' or 'pile'.")


# ---------------------------------------------------------------------------
# Forward pass and collector pattern
# ---------------------------------------------------------------------------

def run_inference(model, input_ids: torch.Tensor, attention_mask: torch.Tensor,
                  output_hidden_states: bool = False, output_attentions: bool = False,
                  labels: Optional[torch.Tensor] = None):
    """Run a single forward pass under ``torch.no_grad()`` and return the model output.

    labels:
        Target ids for the model's built-in LM loss; positions set to -100 are
        ignored by the loss. Defaults to ``input_ids`` (every position scored).

    output_hidden_states / output_attentions are the hooks for future
    data-weighted statistics (e.g. a data-weighted W_QK):
      - output_hidden_states=True → out.hidden_states[layer]
      - output_attentions=True → out.attentions[layer][:, head]
        (per the Transformers docs each entry is (batch, heads, seq, seq))
    """
    if labels is None:
        labels = input_ids
    with torch.no_grad():
        return model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )


class NLLCollector:
    """Accumulates token-weighted negative log-likelihood (nats) across windows."""

    name = "nll"

    def __init__(self):
        self._total_nll = 0.0
        self._n_tokens = 0

    def update(self, out, n_tokens: int):
        """Add one window's contribution.

        ``out.loss`` is the mean NLL over the window's scored (label != -100)
        positions, so multiplying by ``n_tokens`` — the number of those
        positions — recovers the window's summed NLL.
        """
        if n_tokens <= 0:
            # Nothing was scored; out.loss is not meaningful (typically NaN).
            return
        self._total_nll += out.loss.item() * n_tokens
        self._n_tokens += n_tokens

    def result(self):
        """Mean NLL per scored token, or NaN if nothing was scored."""
        if self._n_tokens == 0:
            return float("nan")
        return self._total_nll / self._n_tokens


def eval_loop(model, input_ids: torch.Tensor, device, stride: int = 512,
              max_tokens: Optional[int] = None, collectors=None) -> dict:
    """Sliding-window evaluation loop (canonical HuggingFace approach).

    Windows of up to ``model.config.max_position_embeddings`` tokens start every
    ``stride`` tokens. Within each window only the tokens not already scored by
    the previous window are targets; earlier positions are labelled -100 and act
    purely as context. Every token after the first is therefore scored at most
    once, with as much left context as the window allows.

    Collectors receive ``(out, n_scored)`` per window, where ``n_scored`` is the
    number of positions that contributed to ``out.loss``. Extend by adding new
    collector classes (e.g. DressedWQKCollector for a data-weighted W_QK).

    Returns:
        ``{collector.name: collector.result()}`` for every collector.

    Raises:
        ValueError: if ``stride`` is not positive.
    """
    if collectors is None:
        collectors = [NLLCollector()]
    if stride <= 0:
        raise ValueError(f"stride must be positive, got {stride}")

    max_length = model.config.max_position_embeddings
    # A stride longer than the context window would silently skip tokens.
    stride = min(stride, max_length)

    seq_len = min(len(input_ids), max_tokens or len(input_ids))
    input_ids = input_ids[:seq_len].unsqueeze(0).to(device)

    prev_end = 0
    for begin in range(0, seq_len, stride):
        end = min(begin + max_length, seq_len)
        window = input_ids[:, begin:end]

        # Tokens in [begin, prev_end) were already scored by the previous window;
        # keep them as context only.
        n_context = max(prev_end - begin, 0)
        labels = window.clone()
        labels[:, :n_context] = -100
        # HF causal-LM heads shift labels left by one internally, so the first
        # position of a window is never scored; count what the loss actually used.
        n_scored = int((labels[:, 1:] != -100).sum())

        mask = torch.ones_like(window)
        out = run_inference(model, window, mask, labels=labels)
        for collector in collectors:
            collector.update(out, n_scored)

        prev_end = end
        if end == seq_len:
            break

    return {c.name: c.result() for c in collectors}


# ---------------------------------------------------------------------------
# Per-model evaluation
# ---------------------------------------------------------------------------

def _bits_per_byte(mean_nll: float, tokenizer, token_ids: torch.Tensor) -> float:
    """Convert a mean per-token NLL (nats) into bits per UTF-8 byte of *token_ids*' text.

    bits/byte = (nll / ln 2) * (tokens / bytes). Decoding uses
    ``clean_up_tokenization_spaces=False`` so the byte count reflects the decoded
    token text rather than a whitespace-normalised version of it.
    Returns NaN when the decoded text is empty.
    """
    text = tokenizer.decode(token_ids.tolist(), clean_up_tokenization_spaces=False)
    n_bytes = len(text.encode("utf-8"))
    if n_bytes == 0:
        return float("nan")
    return (mean_nll / math.log(2)) * len(token_ids) / n_bytes


def evaluate_model(model_name: str, revision: Optional[str], corpus: str, pile_tokens: int,
                   cache_dir: str, device_str: Optional[str], stride: int = 512,
                   max_tokens: Optional[int] = None, pile_cache: Optional[str] = None) -> dict:
    """Download one checkpoint, evaluate it on *corpus*, and return a result row.

    The returned dict has keys:
        model, revision ("main" when *revision* is None),
        step (integer parsed from "step<N>" revisions, else None),
        perplexity, nll (mean nats per scored token),
        bpb (bits per UTF-8 byte of the evaluated text), corpus,
        n_tokens (tokens actually evaluated, after the ``max_tokens`` cap).
    """
    model_config = get_model_config(model_name)
    revision_str = revision or "main"

    print(f" Downloading {model_name} @ {revision_str} ...")
    cache_path = snapshot_download(
        repo_id=model_config.repo_id,
        revision=revision,
        cache_dir=f"{cache_dir}/{model_name}/{revision_str}",
        allow_patterns=["*.safetensors", "*.bin", "*.json", "tokenizer*"],
        resume_download=True,
    )

    device = torch.device(device_str or get_device())
    print(f" Loading model on {device} ...")
    tokenizer = AutoTokenizer.from_pretrained(cache_path)
    model = AutoModelForCausalLM.from_pretrained(cache_path, torch_dtype=torch.float32)
    model = model.to(device).eval()

    try:
        print(f" Loading corpus ({corpus}) ...")
        tokens = load_corpus_tokens(corpus, tokenizer, pile_tokens=pile_tokens, pile_cache=pile_cache)
        # Same cap rule as eval_loop (max_tokens of None or 0 means "all tokens").
        n_eval = min(len(tokens), max_tokens or len(tokens))
        print(f" Evaluating on {n_eval:,} tokens ...")
        results = eval_loop(model, tokens, device, stride=stride, max_tokens=max_tokens)
    finally:
        # Free the model (and cached GPU blocks) even if evaluation fails, because
        # main() continues with the next checkpoint.
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    nll = results["nll"]
    ppl = math.exp(nll)
    bpb = _bits_per_byte(nll, tokenizer, tokens[:n_eval])

    step = None
    if revision and revision.startswith("step"):
        try:
            step = int(revision[4:])
        except ValueError:
            pass

    return {
        "model": model_name,
        "revision": revision_str,
        "step": step,
        "perplexity": ppl,
        "nll": nll,
        "bpb": bpb,
        "corpus": corpus,
        "n_tokens": n_eval,
    }


# ---------------------------------------------------------------------------
# Output helpers
# ---------------------------------------------------------------------------

def to_long_format(row: dict) -> list[dict]:
    """Convert one evaluation result row into long-format
    (model, revision, step, metric, value, source, corpus) records."""
    base = {"model": row["model"], "revision": row["revision"], "step": row["step"],
            "source": "eval_pass", "corpus": row["corpus"]}
    return [{**base, "metric": metric, "value": row[metric]}
            for metric in ("perplexity", "nll", "bpb")]


def append_to_parquet(rows: list[dict], out_path: str):
    """Upsert *rows* into the parquet file at *out_path*.

    Existing rows whose (model, revision, corpus) key matches any new row are
    dropped before the new rows are appended. The file and its parent directory
    are created if missing. Does nothing when *rows* is empty.
    """
    if not rows:
        return
    new_df = pd.DataFrame(rows)
    key = ["model", "revision", "corpus"]
    if os.path.exists(out_path):
        existing = pd.read_parquet(out_path)
        # Drop any existing rows for (model, revision, corpus) we're replacing.
        new_keys = list(new_df[key].itertuples(index=False, name=None))
        replaced = pd.MultiIndex.from_frame(existing[key]).isin(new_keys)
        combined = pd.concat([existing[~replaced], new_df], ignore_index=True)
    else:
        combined = new_df
    # Path(...).parent is "." for a bare filename, unlike os.path.dirname -> "".
    Path(out_path).parent.mkdir(parents=True, exist_ok=True)
    combined.to_parquet(out_path, index=False)
    print(f" Saved {len(new_df)} rows → {out_path}")


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

def main():
    """Parse CLI arguments and evaluate each requested (model, revision) pair."""
    parser = argparse.ArgumentParser(description="Compute perplexity for transformer models")
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--model", type=str)
    group.add_argument("--all-models", action="store_true")
    parser.add_argument("--revision", type=str, default=None)
    parser.add_argument("--all-revisions", action="store_true")
    parser.add_argument("--corpus", type=str, default="wikitext103", choices=["wikitext103", "pile"])
    parser.add_argument("--pile-tokens", type=int, default=204800,
                        help="Token count to evaluate from Pile corpus (default: 200K)")
    parser.add_argument("--pile-cache", type=str, default=None,
                        help="Path to pre-materialized Pile corpus (.jsonl.gz) from prepare_eval_corpus.py")
    parser.add_argument("--max-tokens", type=int, default=None,
                        help="Cap total tokens evaluated (default: all)")
    parser.add_argument("--stride", type=int, default=512)
    parser.add_argument("--out", type=str, default="outputs/eval_metrics/eval_metrics.parquet")
    parser.add_argument("--cache", type=str, default="/Flux/Projects/transformer-analysis/downloads")
    parser.add_argument("--device", type=str, default=None, choices=["cuda", "mps", "cpu"])
    args = parser.parse_args()

    models = list(MODEL_CONFIGS) if args.all_models else [args.model]
    for model_name in models:
        try:
            model_config = get_model_config(model_name)
        except ValueError as e:
            print(f"Skipping {model_name}: {e}")
            continue

        if args.all_revisions:
            revisions = model_config.revisions or [None]
        elif args.revision:
            revisions = [args.revision]
        else:
            revisions = [None]

        print(f"\n{'='*60}\n{model_name} — {len(revisions)} revision(s)\n{'='*60}")
        for rev in tqdm(revisions, desc=model_name):
            try:
                result = evaluate_model(
                    model_name=model_name,
                    revision=rev,
                    corpus=args.corpus,
                    pile_tokens=args.pile_tokens,
                    cache_dir=args.cache,
                    device_str=args.device,
                    stride=args.stride,
                    max_tokens=args.max_tokens,
                    pile_cache=args.pile_cache,
                )
            except Exception as e:  # top-level boundary: report and move to next checkpoint
                print(f" ERROR {model_name} @ {rev}: {type(e).__name__}: {e}", file=sys.stderr)
                continue
            print(f" ppl={result['perplexity']:.2f} bpb={result['bpb']:.4f}")
            # Persist each checkpoint immediately so a crash late in a long
            # --all-revisions sweep does not discard the results already computed.
            append_to_parquet(to_long_format(result), args.out)


if __name__ == "__main__":
    main()