| """ |
| Compute sliding-window perplexity of a trained LLM on a binary token dataset. |
| |
The sliding-window approach avoids the boundary effect of chunking: a window
of ``seq_len`` tokens is evaluated every ``stride`` tokens. Within each
window, the leading tokens (which overlap the previous window) serve as
context only; NLL contributions are accumulated solely for the *new*
``stride`` tokens at the right end of the window, so no position is
double-counted across steps.
| |
| Reference: Press et al., 2022 "Train Short, Test Long" (sliding-window PPL). |
| |
| Usage: |
| python eval/perplexity.py \ |
| --checkpoint checkpoints/checkpoint-0100000 \ |
| --data data/val.bin \ |
| --seq_len 2048 \ |
| --batch_size 4 \ |
| --device cuda:0 \ |
| --stride 512 |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import math |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader, Dataset |
| from tqdm import tqdm |
| from model.transformer import LLM |
|
|
|
|
| |
| |
| |
|
|
class SlidingWindowDataset(Dataset):
    """
    Yields (input_ids, targets, loss_mask) tuples for sliding-window PPL.

    ``loss_mask`` is 1 for positions that contribute to the perplexity
    estimate (i.e. the *new* stride tokens at the right end of the window)
    and 0 for the context-only positions. Padded positions carry target
    ``-100`` so they are also skipped by ``ignore_index`` downstream.

    Args:
        tokens: Flat 1-D numpy array of token IDs (uint16).
        seq_len: Context window size.
        stride: Step size between consecutive windows.

    Raises:
        ValueError: If ``stride`` is not in ``[1, seq_len]`` (a stride
            larger than ``seq_len`` would leave gaps of unscored tokens).
    """

    def __init__(self, tokens: np.ndarray, seq_len: int, stride: int) -> None:
        if not 1 <= stride <= seq_len:
            raise ValueError(
                f"stride must be in [1, seq_len={seq_len}], got {stride}"
            )
        self.tokens = tokens
        self.seq_len = seq_len
        self.stride = stride

        # Window count such that every token except token 0 (which is never
        # a prediction target) is scored exactly once:
        #   - at least 2 tokens are needed to form a single (input, target)
        #     pair, otherwise there are no windows at all;
        #   - one window always suffices for the first seq_len tokens;
        #   - each additional window contributes `stride` new targets, so
        #     the overhang beyond seq_len needs ceil(extra / stride) more.
        # (The previous formula returned 0 windows when
        # len(tokens) == seq_len and dropped the tail whenever the overhang
        # was an exact multiple of stride.)
        if len(tokens) < 2:
            self.n_windows = 0
        else:
            extra = max(0, len(tokens) - seq_len)
            self.n_windows = 1 + (extra + stride - 1) // stride

    def __len__(self) -> int:
        return self.n_windows

    def __getitem__(self, idx: int):
        start = idx * self.stride
        end = start + self.seq_len

        # The last window may run past the end of the data; truncate it.
        actual_end = min(end, len(self.tokens))
        chunk_len = actual_end - start

        input_ids = torch.zeros(self.seq_len, dtype=torch.long)
        targets = torch.full((self.seq_len,), fill_value=-100, dtype=torch.long)
        loss_mask = torch.zeros(self.seq_len, dtype=torch.bool)

        if chunk_len > 1:
            toks = torch.from_numpy(
                self.tokens[start:actual_end].astype(np.int64)
            )
            input_ids[:chunk_len] = toks
            # Next-token targets: position p predicts token start + p + 1.
            targets[: chunk_len - 1] = toks[1:]

            # Score only the tokens this window is responsible for. The
            # first window scores all of its targets; every later window
            # scores only the trailing `stride` targets (positions
            # seq_len - stride - 1 .. seq_len - 2 predict the stride tokens
            # not covered by the previous window). The original code
            # started the mask at `stride` — the LEFT end of the window —
            # which double-counted the overlap region whenever
            # stride < seq_len - stride.
            new_start = 0 if idx == 0 else self.seq_len - self.stride - 1
            mask_end = chunk_len - 1  # targets exist only below this index
            loss_mask[new_start:mask_end] = True

        return input_ids, targets, loss_mask
|
|
|
|
| |
| |
| |
|
|
@torch.inference_mode()
def compute_perplexity(
    model: torch.nn.Module,
    data_path: str,
    seq_len: int,
    batch_size: int,
    device: str,
    stride: int,
) -> float:
    """
    Compute sliding-window perplexity on the token file at ``data_path``.

    Args:
        model: Language model whose forward pass returns ``(logits, _)``.
        data_path: Path to a flat uint16 binary token file.
        seq_len: Window length fed to the model.
        batch_size: Number of windows per forward pass.
        device: Torch device string batches are moved to.
        stride: Step between consecutive window starts.

    Returns:
        Perplexity (float).

    Raises:
        FileNotFoundError: If ``data_path`` does not exist.
        ValueError: If no window fits into the token stream.
        RuntimeError: If no position contributed to the estimate.
    """
    data_file = Path(data_path)
    if not data_file.exists():
        raise FileNotFoundError(f"Data file not found: {data_file}")

    # memmap keeps the (potentially huge) token file out of RAM.
    token_stream = np.memmap(data_file, dtype="uint16", mode="r")
    n_tokens = len(token_stream)
    print(f"Loaded {n_tokens:,} tokens from {data_file}")

    windows = SlidingWindowDataset(token_stream, seq_len=seq_len, stride=stride)
    if not len(windows):
        raise ValueError(
            f"No windows fit in {n_tokens} tokens with seq_len={seq_len}."
        )

    loader = DataLoader(
        windows,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
    )

    model.eval()

    nll_sum, n_scored = 0.0, 0

    progress = tqdm(loader, desc="Evaluating perplexity", unit="batch")
    for ids, tgt, mask in progress:
        ids = ids.to(device)
        tgt = tgt.to(device)
        mask = mask.to(device)

        logits, _ = model(ids)

        # Per-position NLL. Padded positions carry target -100 and yield
        # zero loss via ignore_index, so they never leak into the sum.
        flat_nll = F.cross_entropy(
            logits.flatten(0, 1),
            tgt.flatten(),
            ignore_index=-100,
            reduction="none",
        )
        per_pos = flat_nll.view_as(tgt)

        # Aggregate only the positions this window is responsible for.
        nll_sum += per_pos[mask].sum().item()
        n_scored += int(mask.sum().item())

    if n_scored == 0:
        raise RuntimeError("No valid token positions were evaluated.")

    return math.exp(nll_sum / n_scored)
|
|
|
|
| |
| |
| |
|
|
def parse_args() -> argparse.Namespace:
    """Parse and return the command-line arguments for this script."""
    cli = argparse.ArgumentParser(
        description="Compute sliding-window perplexity of a trained LLM."
    )
    cli.add_argument(
        "--checkpoint", required=True, help="Path to the checkpoint directory."
    )
    cli.add_argument(
        "--data", required=True, help="Path to the .bin token data file."
    )
    cli.add_argument(
        "--seq_len", type=int, default=2048,
        help="Context window length (default: 2048).",
    )
    cli.add_argument(
        "--batch_size", type=int, default=4,
        help="Evaluation batch size (default: 4).",
    )
    cli.add_argument(
        "--device", default="cuda:0",
        help="Torch device string (default: cuda:0).",
    )
    cli.add_argument(
        "--stride", type=int, default=512,
        help=(
            "Stride for sliding window PPL; smaller = more accurate, "
            "slower (default: 512)."
        ),
    )
    return cli.parse_args()
|
|
|
|
| |
| |
| |
|
|
def main() -> None:
    """Script entry point: load a checkpoint and report its perplexity."""
    args = parse_args()

    checkpoint_dir = Path(args.checkpoint)
    if not checkpoint_dir.exists():
        raise FileNotFoundError(f"Checkpoint directory not found: {checkpoint_dir}")

    # Evaluation runs in half precision to cut memory and bandwidth.
    print(f"Loading model from: {checkpoint_dir}")
    model = LLM.from_pretrained(str(checkpoint_dir)).to(
        device=args.device, dtype=torch.float16
    )
    model.eval()
    print(f"Model parameters: {model.num_params / 1e6:.1f}M")

    print(
        f"\nPerplexity config: seq_len={args.seq_len}, "
        f"stride={args.stride}, batch_size={args.batch_size}"
    )

    ppl = compute_perplexity(
        model=model,
        data_path=args.data,
        seq_len=args.seq_len,
        batch_size=args.batch_size,
        device=args.device,
        stride=args.stride,
    )

    # Bits/token is ln(ppl) converted from nats to bits.
    banner = "=" * 50
    print("\n" + banner)
    print(f" Perplexity: {ppl:.4f}")
    print(f" Bits/token: {math.log2(math.e) * math.log(ppl):.4f}")
    print(banner)
|
|
|
|
# Script entry point: run evaluation when invoked directly (see module
# docstring for the full CLI usage example).
if __name__ == "__main__":
    main()
|
|