from __future__ import annotations import argparse import math import sys from pathlib import Path import torch from torch.utils.data import DataLoader ROOT = Path(__file__).resolve().parents[1] sys.path.append(str(ROOT / "src")) from sllm.checkpoint import load_checkpoint from sllm.config import ModelConfig, load_json from sllm.data import SequentialEvalDataset from sllm.model import SLLMForCausalLM from sllm.utils import autocast_context, get_device, resolve_runtime_precision, setup_logger def build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser(description="Evaluate perplexity on validation shards.") parser.add_argument("--checkpoint", required=True, help="Path to checkpoint file.") parser.add_argument("--model-config", required=False, help="Optional model config JSON path.") parser.add_argument("--data-dir", required=True, help="Validation root directory.") parser.add_argument("--seq-len", type=int, default=2_048) parser.add_argument("--batch-size", type=int, default=8) parser.add_argument("--batches", type=int, default=50) parser.add_argument("--precision", default="bf16") return parser def main() -> None: args = build_parser().parse_args() logger, log_path = setup_logger("sllm.eval_perplexity", Path("outputs/eval"), "eval_perplexity") logger.info("Perplexity evaluation started") logger.info("Log file: %s", log_path) logger.info("Arguments | checkpoint=%s model_config=%s data_dir=%s seq_len=%s batch_size=%s batches=%s precision=%s", args.checkpoint, args.model_config, args.data_dir, args.seq_len, args.batch_size, args.batches, args.precision) device = get_device() runtime_precision, precision_warning = resolve_runtime_precision(device, args.precision) if precision_warning is not None: logger.warning(precision_warning) payload = load_checkpoint(args.checkpoint, map_location=device) if args.model_config: model_config = ModelConfig.from_dict(load_json(args.model_config)) else: model_config = ModelConfig.from_dict(payload["model_config"]) model = SLLMForCausalLM(model_config).to(device) model.load_state_dict(payload["model"]) model.eval() dataset = SequentialEvalDataset( data_dir=args.data_dir, split="val", seq_len=args.seq_len, max_batches=args.batches * args.batch_size, ) loader = DataLoader(dataset, batch_size=args.batch_size, num_workers=0) losses = [] with torch.no_grad(): for batch_index, batch in enumerate(loader): if batch_index >= args.batches: break batch = {key: value.to(device) for key, value in batch.items()} with autocast_context(device, runtime_precision): loss = model(**batch)["loss"] losses.append(loss.detach().float().item()) mean_loss = float(sum(losses) / max(1, len(losses))) perplexity = math.exp(min(mean_loss, 20)) logger.info("Perplexity evaluation finished | val_loss=%.4f perplexity=%.2f", mean_loss, perplexity) print(f"val_loss={mean_loss:.4f}") print(f"perplexity={perplexity:.2f}") if __name__ == "__main__": main()