# frankenstallm/source/eval/fast_ppl.py
# (uploaded via huggingface_hub, PR #29, rev 5b1ff4d)
"""
Fast PPL evaluation on B200 — bfloat16, proper CUDA device setup.
Usage:
CUDA_VISIBLE_DEVICES=0 python eval/fast_ppl.py \
--checkpoint checkpoints/korean_3b_fp8_run1/checkpoint-0057000 \
--data data/3b_val.bin \
--max_tokens 10000000 \
--batch_size 32 \
--output eval/outputs/ppl_3b_val.json
"""
from __future__ import annotations
import argparse
import json
import math
import sys
import time
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
if str(_PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(_PROJECT_ROOT))
from model.transformer import LLM
class SlidingWindowDataset(Dataset):
    """Sliding-window view over a flat token array for perplexity eval.

    Window ``idx`` covers tokens ``[idx*stride, idx*stride + seq_len)``.
    Each window's loss mask selects only the target positions that the
    previous window did not already score, so every token (except the very
    first, which has no left context) is scored exactly once; tokens past
    the first window keep ``seq_len - stride`` tokens of left context.

    Args:
        tokens: 1-D integer array of token ids (e.g. uint16 from disk).
        seq_len: window length fed to the model.
        stride: step between window starts (new tokens scored per window).
    """

    def __init__(self, tokens: np.ndarray, seq_len: int, stride: int):
        self.tokens = tokens
        self.seq_len = seq_len
        self.stride = stride
        if len(tokens) <= 1:
            # Need at least two tokens to form one (input, target) pair.
            self.n_windows = 0
        else:
            # Enough windows so the last one reaches the final token:
            # (n-1)*stride + seq_len >= len(tokens).  The original formula
            # omitted the trailing window (`+1`), giving 0 windows when
            # len == seq_len and leaving up to `stride` tail tokens unscored.
            extra = max(0, -(-(len(tokens) - seq_len) // stride))  # int ceil
            self.n_windows = 1 + extra

    def __len__(self):
        return self.n_windows

    def __getitem__(self, idx):
        start = idx * self.stride
        actual_end = min(start + self.seq_len, len(self.tokens))
        chunk_len = actual_end - start
        input_ids = torch.zeros(self.seq_len, dtype=torch.long)
        targets = torch.full((self.seq_len,), -100, dtype=torch.long)
        loss_mask = torch.zeros(self.seq_len, dtype=torch.bool)
        if chunk_len > 1:
            toks = torch.from_numpy(self.tokens[start:actual_end].astype(np.int64))
            input_ids[:chunk_len] = toks
            # Position p predicts token start+p+1; last valid target is at
            # position chunk_len-2.
            targets[:chunk_len - 1] = toks[1:]
            # First position whose target token was NOT scored by window
            # idx-1 (which scored up to token (idx-1)*stride + seq_len - 1):
            # p = seq_len - stride - 1.  The original used `stride`, which
            # double-counts targets when seq_len > 2*stride + 1 and skips
            # targets when seq_len < 2*stride + 1.
            new_start = 0 if idx == 0 else max(0, self.seq_len - self.stride - 1)
            loss_mask[new_start:chunk_len - 1] = True
        return input_ids, targets, loss_mask
def main():
    """Evaluate perplexity of a checkpoint over a binary uint16 token file.

    Parses CLI args, loads the model in bfloat16 on ``cuda:0``, streams
    sliding windows through it, and prints PPL, bits/token and average NLL.
    Writes the result dict as JSON when ``--output`` is given.

    Returns:
        dict with dataset name, token counts, perplexity, bits_per_token,
        avg_nll, elapsed time, and the eval configuration.

    Raises:
        SystemExit: if the data file yields no scorable windows or tokens
            (previously this crashed with ZeroDivisionError).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", required=True)
    parser.add_argument("--data", required=True)
    parser.add_argument("--seq_len", type=int, default=2048)
    parser.add_argument("--stride", type=int, default=512)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--max_tokens", type=int, default=0,
                        help="Max tokens to evaluate (0=all)")
    parser.add_argument("--output", default=None, help="JSON output path")
    args = parser.parse_args()

    device = "cuda:0"  # Use CUDA_VISIBLE_DEVICES to select GPU

    print(f"Loading model from {args.checkpoint}...")
    t0 = time.time()
    model = LLM.from_pretrained(args.checkpoint)
    model = model.to(device=device, dtype=torch.bfloat16)
    model.eval()
    num_params = sum(p.numel() for p in model.parameters())
    print(f"Model: {num_params/1e6:.1f}M params, bfloat16, loaded in {time.time()-t0:.1f}s")

    # Token files are raw uint16 ids written straight to disk.
    tokens = np.fromfile(args.data, dtype=np.uint16)
    total_tokens = len(tokens)
    if args.max_tokens > 0 and total_tokens > args.max_tokens:
        tokens = tokens[:args.max_tokens]
        print(f"Using {len(tokens):,}/{total_tokens:,} tokens (sampled)")
    else:
        print(f"Using all {total_tokens:,} tokens")

    ds = SlidingWindowDataset(tokens, args.seq_len, args.stride)
    if len(ds) == 0:
        # Guard: too few tokens to form a single (input, target) pair —
        # without this the averaging below divides by zero.
        raise SystemExit(
            f"No evaluable windows in {args.data} "
            f"({len(tokens)} tokens, seq_len={args.seq_len})"
        )
    dl = DataLoader(ds, batch_size=args.batch_size, shuffle=False,
                    num_workers=4, pin_memory=True)
    n_batches = len(dl)
    print(f"Windows: {len(ds):,}, Batches: {n_batches:,}, "
          f"seq_len={args.seq_len}, stride={args.stride}, bs={args.batch_size}")

    total_nll = 0.0
    total_count = 0
    t_start = time.time()
    with torch.inference_mode():
        for i, (inp, tgt, mask) in enumerate(dl):
            inp = inp.to(device)
            tgt = tgt.to(device)
            mask = mask.to(device)
            logits, _ = model(inp)
            # Per-position NLL; masked positions are summed out below, so
            # reduction="none" + explicit mask keeps the count exact.
            ce = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                tgt.view(-1),
                reduction="none"
            ).view(mask.shape)
            nll = (ce * mask.float()).sum().item()
            cnt = mask.sum().item()
            total_nll += nll
            total_count += cnt
            if (i + 1) % 100 == 0 or (i + 1) == n_batches:
                elapsed = time.time() - t_start
                # Guard against fully-masked batches early in the stream.
                running_ppl = (math.exp(total_nll / total_count)
                               if total_count else float("inf"))
                speed = (i + 1) / elapsed
                eta = (n_batches - i - 1) / speed
                print(f"  [{i+1}/{n_batches}] PPL={running_ppl:.4f} "
                      f"({speed:.1f} batch/s, ETA {eta:.0f}s)", flush=True)

    elapsed = time.time() - t_start
    if total_count == 0:
        raise SystemExit("No tokens were scored; check data/seq_len/stride.")
    avg_nll = total_nll / total_count
    ppl = math.exp(avg_nll)
    bpt = avg_nll / math.log(2)  # nats -> bits

    data_name = Path(args.data).stem
    print(f"\n{'='*50}")
    print(f"  Dataset: {data_name}")
    print(f"  Tokens evaluated: {total_count:,}")
    print(f"  Perplexity: {ppl:.4f}")
    print(f"  Bits/token: {bpt:.4f}")
    print(f"  Avg NLL: {avg_nll:.6f}")
    print(f"  Time: {elapsed:.1f}s ({elapsed/60:.1f}min)")
    print(f"{'='*50}")

    result = {
        "dataset": data_name,
        "data_file": args.data,
        "total_tokens": int(total_tokens),
        "eval_tokens": int(total_count),
        "max_tokens_used": args.max_tokens if args.max_tokens > 0 else int(total_tokens),
        "perplexity": round(ppl, 4),
        "bits_per_token": round(bpt, 4),
        "avg_nll": round(avg_nll, 6),
        "elapsed_sec": round(elapsed, 1),
        "config": {
            "seq_len": args.seq_len,
            "stride": args.stride,
            "batch_size": args.batch_size,
            "dtype": "bfloat16",
        }
    }
    if args.output:
        Path(args.output).parent.mkdir(parents=True, exist_ok=True)
        with open(args.output, "w") as f:
            json.dump(result, f, indent=2, ensure_ascii=False)
        print(f"Saved to {args.output}")
    return result
# Script entry point; the guard keeps `import fast_ppl` side-effect free.
if __name__ == "__main__":
    main()