#!/usr/bin/env python3
"""
Evaluate a student checkpoint against the frozen teacher using the same
single-process sharded setup and fixed eval cache as distill_sharded.py.
"""

from __future__ import annotations

import argparse
from pathlib import Path

import torch
from transformers import AutoTokenizer

import distill_sharded as ds


def main() -> None:
    p = argparse.ArgumentParser(description=__doc__)
    p.add_argument("--config", required=True)
    p.add_argument("--student", default=None, help="Optional override for the student checkpoint path")
    p.add_argument("--samples", type=int, default=None, help="Optional override for the number of eval samples")
    args = p.parse_args()

    cfg = ds.load_config(args.config)
    # Apply CLI overrides on top of the loaded config.
    if args.student is not None:
        cfg["model"]["student"] = args.student
    if args.samples is not None:
        cfg["eval"]["samples"] = args.samples

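    # Mirror the single-process sharded layout of distill_sharded.py:
    # the student sits on one device, the teacher is sharded across several.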
    student_device = torch.device(cfg["model"]["student_device"])
    teacher_devices = list(cfg["model"]["teacher_devices"])

    tokenizer = AutoTokenizer.from_pretrained(cfg["model"]["tokenizer"], trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Tokenizers without an explicit pad token fall back to EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token
    pad_id = tokenizer.pad_token_id

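    # Gradient checkpointing is disabled since evaluation needs no backward pass.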
    student = ds.load_student(
        cfg["model"]["student"],
        ds.parse_dtype(cfg["train"]["student_dtype"]),
        grad_ckpt=False,
        attn_impl=cfg["train"]["attn_implementation"],
    )
    student.to(student_device)
    student.eval()

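    # The frozen teacher is sharded across teacher_devices; record the device
    # that receives its input ids.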
    teacher = ds.load_teacher(
        cfg["model"]["teacher"],
        ds.parse_dtype(cfg["train"]["teacher_dtype"]),
        attn_impl=cfg["train"]["attn_implementation"],
        devices=teacher_devices,
        max_mem_gb=cfg["model"]["teacher_max_memory_gb"],
    )
    teacher_input_device, _ = ds.get_teacher_devices(teacher)

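    # Reuse the fixed eval cache when it exists; otherwise build it once from
    # the same streaming data mixture used by distill_sharded.py.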
    specs = ds.build_dataset_specs(cfg["data"])
    if Path(cfg["eval"]["cache_path"]).exists():
        eval_batches = ds.build_or_load_eval_cache(cfg["eval"]["cache_path"])
    else:
        eval_loader = ds.MixedStreamingLoader(
            specs=specs,
            tokenizer=tokenizer,
            min_chars=cfg["data"]["min_chars"],
            max_seq_len=cfg["data"]["max_seq_len"],
            kl_start_pos=cfg["data"]["kl_start_pos"],
            seed=cfg["eval"]["seed"],
            shuffle_buffer=cfg["data"]["shuffle_buffer"],
        )
        eval_batches = ds.build_or_load_eval_cache(
            cfg["eval"]["cache_path"],
            eval_loader,
            cfg["eval"]["samples"],
        )
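    # Average KL divergence between teacher and student over the cached batches.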
    kl = ds.evaluate(
        student,
        teacher,
        eval_batches,
        pad_id,
        cfg["data"]["kl_start_pos"],
        cfg["train"]["kl_chunk_size"],
        student_device,
        teacher_input_device,
    )
    print(f"{kl:.6f}")


if __name__ == "__main__":
    main()