# QED-75M_artifacts/scripts/eval_perplexity.py
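"""Evaluate validation-set perplexity for a saved SLLM checkpoint.

Illustrative invocation (paths are examples, not necessarily the repository layout):

    python scripts/eval_perplexity.py \
        --checkpoint outputs/checkpoints/latest.pt \
        --data-dir data/val \
        --seq-len 2048 --batch-size 8 --batches 50 --precision bf16
"""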
from __future__ import annotations
import argparse
import math
import sys
from pathlib import Path
import torch
from torch.utils.data import DataLoader
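
# Make the project's src/ directory importable when the script is run directly.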
ROOT = Path(__file__).resolve().parents[1]
sys.path.append(str(ROOT / "src"))
from sllm.checkpoint import load_checkpoint
from sllm.config import ModelConfig, load_json
from sllm.data import SequentialEvalDataset
from sllm.model import SLLMForCausalLM
from sllm.utils import autocast_context, get_device, resolve_runtime_precision, setup_logger


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Evaluate perplexity on validation shards.")
    parser.add_argument("--checkpoint", required=True, help="Path to checkpoint file.")
    parser.add_argument("--model-config", required=False, help="Optional model config JSON path.")
    parser.add_argument("--data-dir", required=True, help="Validation root directory.")
    parser.add_argument("--seq-len", type=int, default=2_048)
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--batches", type=int, default=50)
    parser.add_argument("--precision", default="bf16")
    return parser


def main() -> None:
    args = build_parser().parse_args()
    logger, log_path = setup_logger("sllm.eval_perplexity", Path("outputs/eval"), "eval_perplexity")
    logger.info("Perplexity evaluation started")
    logger.info("Log file: %s", log_path)
    logger.info(
        "Arguments | checkpoint=%s model_config=%s data_dir=%s seq_len=%s batch_size=%s batches=%s precision=%s",
        args.checkpoint, args.model_config, args.data_dir, args.seq_len, args.batch_size, args.batches, args.precision,
    )
    device = get_device()
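    # resolve_runtime_precision is expected to fall back to a dtype the device supports
    # and return a warning message when the requested precision is unavailable.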
    runtime_precision, precision_warning = resolve_runtime_precision(device, args.precision)
    if precision_warning is not None:
        logger.warning(precision_warning)
    payload = load_checkpoint(args.checkpoint, map_location=device)
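    # Prefer an explicitly supplied model config; otherwise fall back to the config
    # serialized inside the checkpoint payload.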
    if args.model_config:
        model_config = ModelConfig.from_dict(load_json(args.model_config))
    else:
        model_config = ModelConfig.from_dict(payload["model_config"])
    model = SLLMForCausalLM(model_config).to(device)
    model.load_state_dict(payload["model"])
    model.eval()
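    # Sequential (non-shuffled) pass over the validation split. max_batches is sized as
    # batches * batch_size, presumably so the dataset yields enough samples to fill the
    # requested number of batches.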
    dataset = SequentialEvalDataset(
        data_dir=args.data_dir,
        split="val",
        seq_len=args.seq_len,
        max_batches=args.batches * args.batch_size,
    )
    loader = DataLoader(dataset, batch_size=args.batch_size, num_workers=0)
    losses = []
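    # Accumulate per-batch losses with gradients disabled; autocast_context applies the
    # resolved mixed-precision dtype for the forward pass.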
    with torch.no_grad():
        for batch_index, batch in enumerate(loader):
            if batch_index >= args.batches:
                break
            batch = {key: value.to(device) for key, value in batch.items()}
            with autocast_context(device, runtime_precision):
                loss = model(**batch)["loss"]
            losses.append(loss.detach().float().item())
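    # Clamp the exponent so an unusually large loss cannot overflow math.exp.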
    mean_loss = float(sum(losses) / max(1, len(losses)))
    perplexity = math.exp(min(mean_loss, 20))
    logger.info("Perplexity evaluation finished | val_loss=%.4f perplexity=%.2f", mean_loss, perplexity)
    print(f"val_loss={mean_loss:.4f}")
    print(f"perplexity={perplexity:.2f}")


if __name__ == "__main__":
    main()