File size: 3,211 Bytes
9847679
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from __future__ import annotations

import argparse
import math
import sys
from pathlib import Path

import torch
from torch.utils.data import DataLoader

# Repository root: two directory levels above this script's resolved path
# (presumably the script lives in a subdirectory such as ``scripts/``).
ROOT = Path(__file__).resolve().parents[1]
# Expose the in-repo ``src`` directory so ``sllm`` imports work from a checkout.
sys.path.append(str(ROOT.joinpath("src")))

from sllm.checkpoint import load_checkpoint
from sllm.config import ModelConfig, load_json
from sllm.data import SequentialEvalDataset
from sllm.model import SLLMForCausalLM
from sllm.utils import autocast_context, get_device, resolve_runtime_precision, setup_logger


def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for perplexity evaluation.

    Returns:
        A parser with required ``--checkpoint`` and ``--data-dir`` options,
        an optional ``--model-config`` path, integer ``--seq-len``,
        ``--batch-size`` and ``--batches`` options (defaults 2048, 8, 50),
        and a ``--precision`` string defaulting to ``"bf16"``.
    """
    cli = argparse.ArgumentParser(description="Evaluate perplexity on validation shards.")

    # Input locations.
    cli.add_argument("--checkpoint", required=True, help="Path to checkpoint file.")
    cli.add_argument("--model-config", required=False, help="Optional model config JSON path.")
    cli.add_argument("--data-dir", required=True, help="Validation root directory.")

    # Evaluation sizing: all plain integer flags, registered in declaration order.
    integer_flags = (
        ("--seq-len", 2_048),
        ("--batch-size", 8),
        ("--batches", 50),
    )
    for flag_name, default_value in integer_flags:
        cli.add_argument(flag_name, type=int, default=default_value)

    cli.add_argument("--precision", default="bf16")
    return cli


def _sum_validation_loss(
    model: torch.nn.Module,
    loader: DataLoader,
    device,
    runtime_precision,
    max_batches: int,
) -> tuple[float, int]:
    """Accumulate sample-weighted loss over at most ``max_batches`` batches.

    Args:
        model: Called as ``model(**batch)``; its result is indexed with
            ``"loss"`` to obtain a scalar loss tensor.
        loader: Yields dicts of tensors; each is moved to ``device`` first.
        device: Device returned by ``get_device()``.
        runtime_precision: Precision passed through to ``autocast_context``.
        max_batches: Upper bound on the number of batches consumed.

    Returns:
        ``(weighted_loss_sum, sample_count)``: the sum of each batch's loss
        multiplied by its number of samples, and the total number of samples.
        ``sample_count`` is 0 when the loader yields nothing.
    """
    weighted_loss_sum = 0.0
    sample_count = 0
    with torch.no_grad():
        for batch_index, batch in enumerate(loader):
            if batch_index >= max_batches:
                break
            batch = {key: value.to(device) for key, value in batch.items()}
            # DataLoader's default collate stacks samples along dim 0, so the
            # leading dimension is this batch's sample count; the final batch
            # can be smaller than --batch-size and must not be over-weighted.
            batch_samples = next(iter(batch.values())).shape[0]
            with autocast_context(device, runtime_precision):
                loss = model(**batch)["loss"]
            # NOTE(review): assumes the returned loss is a mean over equally
            # sized samples (fixed seq_len); if it ignores padded tokens, a
            # token-count weighting would be more exact — confirm in the model.
            weighted_loss_sum += loss.float().item() * batch_samples
            sample_count += batch_samples
    return weighted_loss_sum, sample_count


def main() -> None:
    """Evaluate validation loss and perplexity for a checkpoint.

    Parses CLI arguments, restores the model from ``--checkpoint`` (using
    ``--model-config`` when given, otherwise the config stored in the
    checkpoint), runs up to ``--batches`` batches from the ``val`` split under
    ``--data-dir``, then logs and prints ``val_loss`` and ``perplexity``.

    Raises:
        SystemExit: if a size argument is non-positive, or if no validation
            samples were evaluated.
    """
    parser = build_parser()
    args = parser.parse_args()
    # Non-positive sizes would either fail deep inside the data pipeline or
    # evaluate nothing at all; reject them as ordinary CLI errors.
    for flag, value in (
        ("--seq-len", args.seq_len),
        ("--batch-size", args.batch_size),
        ("--batches", args.batches),
    ):
        if value <= 0:
            parser.error(f"{flag} must be a positive integer (got {value})")

    logger, log_path = setup_logger("sllm.eval_perplexity", Path("outputs/eval"), "eval_perplexity")
    logger.info("Perplexity evaluation started")
    logger.info("Log file: %s", log_path)
    logger.info(
        "Arguments | checkpoint=%s model_config=%s data_dir=%s seq_len=%s batch_size=%s batches=%s precision=%s",
        args.checkpoint,
        args.model_config,
        args.data_dir,
        args.seq_len,
        args.batch_size,
        args.batches,
        args.precision,
    )
    device = get_device()
    runtime_precision, precision_warning = resolve_runtime_precision(device, args.precision)
    if precision_warning is not None:
        logger.warning(precision_warning)

    payload = load_checkpoint(args.checkpoint, map_location=device)
    # An explicit --model-config takes precedence over the checkpoint's config.
    if args.model_config:
        model_config = ModelConfig.from_dict(load_json(args.model_config))
    else:
        model_config = ModelConfig.from_dict(payload["model_config"])

    model = SLLMForCausalLM(model_config).to(device)
    model.load_state_dict(payload["model"])
    model.eval()

    dataset = SequentialEvalDataset(
        data_dir=args.data_dir,
        split="val",
        seq_len=args.seq_len,
        # NOTE(review): despite its name, this receives batches * batch_size,
        # i.e. a sample count — confirm against SequentialEvalDataset.
        max_batches=args.batches * args.batch_size,
    )
    loader = DataLoader(dataset, batch_size=args.batch_size, num_workers=0)

    loss_sum, sample_count = _sum_validation_loss(
        model, loader, device, runtime_precision, args.batches
    )
    if sample_count == 0:
        # Averaging over nothing would report val_loss=0 / perplexity=1.0,
        # which is indistinguishable from a perfect model; fail loudly instead.
        logger.error("No validation samples were evaluated (data_dir=%s)", args.data_dir)
        raise SystemExit(f"error: no validation samples found under {args.data_dir!r} (split 'val')")

    mean_loss = loss_sum / sample_count
    # exp() is capped to keep the reported number finite; make the cap visible.
    max_log_perplexity = 20.0
    if mean_loss > max_log_perplexity:
        logger.warning(
            "val_loss=%.4f exceeds %.1f; reported perplexity is clamped to exp(%.1f)",
            mean_loss,
            max_log_perplexity,
            max_log_perplexity,
        )
    perplexity = math.exp(min(mean_loss, max_log_perplexity))
    logger.info("Perplexity evaluation finished | val_loss=%.4f perplexity=%.2f", mean_loss, perplexity)
    print(f"val_loss={mean_loss:.4f}")
    print(f"perplexity={perplexity:.2f}")


# Script entry point: run the evaluation only when executed directly, not when imported.
if __name__ == "__main__":
    main()