| """ |
| Fast PPL evaluation on B200 — bfloat16, proper CUDA device setup. |
| |
| Usage: |
| CUDA_VISIBLE_DEVICES=0 python eval/fast_ppl.py \ |
| --checkpoint checkpoints/korean_3b_fp8_run1/checkpoint-0057000 \ |
| --data data/3b_val.bin \ |
| --max_tokens 10000000 \ |
| --batch_size 32 \ |
| --output eval/outputs/ppl_3b_val.json |
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import math |
| import sys |
| import time |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader, Dataset |
|
|
# Make the project root importable so `model.transformer` resolves when this
# file is executed directly as a script (e.g. `python eval/fast_ppl.py`)
# rather than as a package module.
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
if str(_PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(_PROJECT_ROOT))


from model.transformer import LLM
|
|
|
|
class SlidingWindowDataset(Dataset):
    """Sliding windows over a 1-D token array for perplexity evaluation.

    Window ``idx`` covers ``tokens[idx*stride : idx*stride + seq_len]``.
    When windows overlap (``stride < seq_len``), ``loss_mask`` marks only the
    positions whose target tokens were NOT already scored by the previous
    window, so every target token in the stream is counted exactly once.

    Each item is ``(input_ids, targets, loss_mask)``, all of length
    ``seq_len``; short tails are zero-padded with targets set to -100
    (ignored by ``F.cross_entropy``) and mask False.
    """

    def __init__(self, tokens: np.ndarray, seq_len: int, stride: int):
        self.tokens = tokens
        self.seq_len = seq_len
        self.stride = stride
        n = len(tokens)
        if n <= 1:
            # Need at least two tokens to form one (input, target) pair.
            self.n_windows = 0
        elif n <= seq_len:
            self.n_windows = 1
        else:
            # One full window plus enough stride steps to reach the tail.
            # (The original omitted the leading "+1", which produced zero
            # windows when n == seq_len and always left the last ~stride
            # tokens of the data unevaluated.)
            self.n_windows = 1 + (n - seq_len + stride - 1) // stride

    def __len__(self):
        return self.n_windows

    def __getitem__(self, idx):
        start = idx * self.stride
        actual_end = min(start + self.seq_len, len(self.tokens))
        chunk_len = actual_end - start

        input_ids = torch.zeros(self.seq_len, dtype=torch.long)
        targets = torch.full((self.seq_len,), -100, dtype=torch.long)
        loss_mask = torch.zeros(self.seq_len, dtype=torch.bool)

        if chunk_len > 1:
            toks = torch.from_numpy(self.tokens[start:actual_end].astype(np.int64))
            input_ids[:chunk_len] = toks
            targets[:chunk_len - 1] = toks[1:]
            # Position p predicts global token start + p + 1. The previous
            # window already scored targets up to global index
            # (idx-1)*stride + seq_len - 1, so fresh targets begin at local
            # position seq_len - stride - 1. (The original used `stride`
            # here, which re-counted every overlapping target and biased
            # the perplexity.)
            new_start = 0 if idx == 0 else max(0, self.seq_len - self.stride - 1)
            loss_mask[new_start:chunk_len - 1] = True
        return input_ids, targets, loss_mask
|
|
|
|
def main():
    """Evaluate perplexity of a checkpoint on a raw uint16 token file.

    Loads the model in bfloat16, streams sliding windows over the token
    array, accumulates masked token-level NLL, and reports perplexity,
    bits/token and average NLL. Optionally writes a JSON summary.

    Returns:
        dict: the result summary (also written to --output when given).

    Raises:
        RuntimeError: if the data file is too small to yield any
            evaluable (input, target) pair.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", required=True)
    parser.add_argument("--data", required=True)
    parser.add_argument("--seq_len", type=int, default=2048)
    parser.add_argument("--stride", type=int, default=512)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--max_tokens", type=int, default=0,
                        help="Max tokens to evaluate (0=all)")
    parser.add_argument("--output", default=None, help="JSON output path")
    args = parser.parse_args()

    # Fall back to CPU so the script still runs (slowly) without a GPU,
    # instead of crashing at the first .to(device) call.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    print(f"Loading model from {args.checkpoint}...")
    t0 = time.time()
    model = LLM.from_pretrained(args.checkpoint)
    model = model.to(device=device, dtype=torch.bfloat16)
    model.eval()
    num_params = sum(p.numel() for p in model.parameters())
    print(f"Model: {num_params/1e6:.1f}M params, bfloat16, loaded in {time.time()-t0:.1f}s")

    # Token files are stored as raw uint16 ids with no header.
    tokens = np.fromfile(args.data, dtype=np.uint16)
    total_tokens = len(tokens)
    if args.max_tokens > 0 and total_tokens > args.max_tokens:
        tokens = tokens[:args.max_tokens]
        print(f"Using {len(tokens):,}/{total_tokens:,} tokens (sampled)")
    else:
        print(f"Using all {total_tokens:,} tokens")

    ds = SlidingWindowDataset(tokens, args.seq_len, args.stride)
    dl = DataLoader(ds, batch_size=args.batch_size, shuffle=False,
                    num_workers=4, pin_memory=True)
    n_batches = len(dl)
    print(f"Windows: {len(ds):,}, Batches: {n_batches:,}, "
          f"seq_len={args.seq_len}, stride={args.stride}, bs={args.batch_size}")

    total_nll = 0.0
    total_count = 0
    t_start = time.time()

    with torch.inference_mode():
        for i, (inp, tgt, mask) in enumerate(dl):
            # non_blocking pairs with pin_memory=True to overlap H2D copies.
            inp = inp.to(device, non_blocking=True)
            tgt = tgt.to(device, non_blocking=True)
            mask = mask.to(device, non_blocking=True)

            logits, _ = model(inp)
            # Upcast to float32: a bfloat16 log-softmax loses precision that
            # matters for a measurement script. -100 targets are ignored by
            # cross_entropy's default ignore_index and masked out anyway.
            ce = F.cross_entropy(
                logits.float().view(-1, logits.size(-1)),
                tgt.view(-1),
                reduction="none"
            ).view(mask.shape)

            nll = (ce * mask.float()).sum().item()
            cnt = mask.sum().item()
            total_nll += nll
            total_count += cnt

            if (i + 1) % 100 == 0 or (i + 1) == n_batches:
                elapsed = time.time() - t_start
                # max(..., 1) guards against division by zero before any
                # maskable token has been seen.
                running_ppl = math.exp(total_nll / max(total_count, 1))
                speed = (i + 1) / elapsed
                eta = (n_batches - i - 1) / speed
                print(f"  [{i+1}/{n_batches}] PPL={running_ppl:.4f} "
                      f"({speed:.1f} batch/s, ETA {eta:.0f}s)", flush=True)

    if total_count == 0:
        raise RuntimeError(
            "No evaluable tokens: data file too small for seq_len/stride.")

    elapsed = time.time() - t_start
    avg_nll = total_nll / total_count
    ppl = math.exp(avg_nll)
    bpt = avg_nll / math.log(2)  # nats -> bits

    data_name = Path(args.data).stem
    print(f"\n{'='*50}")
    print(f"  Dataset: {data_name}")
    print(f"  Tokens evaluated: {total_count:,}")
    print(f"  Perplexity: {ppl:.4f}")
    print(f"  Bits/token: {bpt:.4f}")
    print(f"  Avg NLL: {avg_nll:.6f}")
    print(f"  Time: {elapsed:.1f}s ({elapsed/60:.1f}min)")
    print(f"{'='*50}")

    result = {
        "dataset": data_name,
        "data_file": args.data,
        "total_tokens": int(total_tokens),
        "eval_tokens": int(total_count),
        "max_tokens_used": args.max_tokens if args.max_tokens > 0 else int(total_tokens),
        "perplexity": round(ppl, 4),
        "bits_per_token": round(bpt, 4),
        "avg_nll": round(avg_nll, 6),
        "elapsed_sec": round(elapsed, 1),
        "config": {
            "seq_len": args.seq_len,
            "stride": args.stride,
            "batch_size": args.batch_size,
            "dtype": "bfloat16",
        }
    }

    if args.output:
        Path(args.output).parent.mkdir(parents=True, exist_ok=True)
        with open(args.output, "w") as f:
            json.dump(result, f, indent=2, ensure_ascii=False)
        print(f"Saved to {args.output}")

    return result
|
|
|
|
# Script entry point: run the evaluation when invoked directly.
if __name__ == "__main__":
    main()
|
|