File size: 5,144 Bytes
2c44909
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
#!/usr/bin/env python3
"""Evaluate perplexity for a progressive-pruned model assembled from cycles."""

import argparse

import torch

try:
    import ppl_eval
except Exception as exc:  # pragma: no cover - optional dependency
    raise SystemExit("ppl_eval.py is required (missing or invalid)") from exc

try:
    from transformers import AutoTokenizer
except Exception as exc:  # pragma: no cover - fail early with clear error
    raise SystemExit("transformers is required: pip install transformers") from exc

from progressive_loader import load_progressive_model


def parse_args() -> argparse.Namespace:
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        Namespace holding model-loading options (``base_model``,
        ``progressive_dir``, ``cycle``, ``dtype``, ``trust_remote_code``,
        ``layer_path``), dataset options (``dataset``, ``dataset_config``,
        ``dataset_split``, ``dataset_text_field``, ``cache_dir``) and
        evaluation options (``num_samples``, ``seq_len``, ``batch_size``,
        ``device``, ``seed``, ``model_family``, ``add_bos``,
        ``max_batches``, ``num_workers``).
    """
    # Prefer the GPU when one is visible to torch.
    default_device = "cuda" if torch.cuda.is_available() else "cpu"

    cli = argparse.ArgumentParser(
        description="Evaluate PPL for a model reconstructed from progressive cycles."
    )
    add = cli.add_argument

    # Where the model comes from and which pruning cycle to rebuild.
    add("--base_model", required=True, help="Base HF model id or path")
    add(
        "--progressive_dir",
        required=True,
        help="Output directory from progressive pruning",
    )
    add("--cycle", type=int, default=None, help="Cycle to load (default: final)")

    # Evaluation data; the "append" actions collect repeated flags into lists.
    add(
        "--dataset",
        action="append",
        default=[],
        help="Evaluation dataset name (repeatable). Defaults to wikitext.",
    )
    add(
        "--dataset_config",
        action="append",
        default=[],
        help="Evaluation dataset config (repeatable or single shared config).",
    )
    add("--dataset_split", default="test", help="Evaluation dataset split (default: test)")
    add(
        "--dataset_text_field",
        default=None,
        help="Evaluation text field override (default: auto-detect)",
    )

    # Evaluation loop sizing and placement.
    add(
        "--num_samples",
        type=int,
        default=0,
        help="Number of token sequences per dataset (0 = all)",
    )
    add("--seq_len", type=int, default=2048, help="Sequence length for eval")
    add("--batch_size", type=int, default=4, help="Batch size for eval")
    add("--device", default=default_device, help="Device for eval")
    add("--seed", type=int, default=0, help="Random seed")

    # Tokenization details (BOS handling).
    add(
        "--model_family",
        type=str,
        choices=["auto", "llama", "qwen"],
        default="auto",
        help="Model family for BOS handling",
    )
    add(
        "--add_bos",
        type=str,
        choices=["auto", "always", "never"],
        default="auto",
        help="Whether to prepend BOS to each sample",
    )

    # Optional limits / data-loading knobs.
    add(
        "--max_batches",
        type=int,
        default=None,
        help="Optional max number of eval batches per dataset",
    )
    add("--cache_dir", default=None, help="Optional datasets cache dir for eval")
    add("--num_workers", type=int, default=0, help="Eval DataLoader workers")

    # Model construction options.
    add(
        "--dtype",
        default="auto",
        choices=["auto", "float32", "float16", "bfloat16"],
        help="Model dtype",
    )
    add(
        "--trust_remote_code",
        action="store_true",
        help="Allow custom model code from hub",
    )
    add("--layer_path", default=None, help="Override layer attribute path if needed")

    return cli.parse_args()


def main() -> None:
    """Evaluate perplexity of a model rebuilt from progressive pruning cycles.

    Parses CLI arguments, seeds torch's RNG, resolves the dataset/config
    pairs, loads the tokenizer and the progressive model, runs
    ``ppl_eval.evaluate_ppl_datasets`` and prints one ``name: ppl`` line per
    returned entry.
    """
    args = parse_args()
    torch.manual_seed(args.seed)

    default_dataset = "wikitext"
    default_config = "wikitext-2-raw-v1"
    datasets = args.dataset or [default_dataset]
    if args.dataset_config:
        configs = list(args.dataset_config)
    else:
        # Apply the wikitext default config only to wikitext. Handing it to
        # other datasets (e.g. ``--dataset ptb_text_only``) would request a
        # config name they do not define.
        # NOTE(review): assumes ppl_eval passes a None config through so the
        # dataset's own default config is used — confirm in ppl_eval.
        configs = [
            default_config if name == default_dataset else None for name in datasets
        ]
    configs = ppl_eval._expand_dataset_configs(datasets, configs)

    # Load the cheap tokenizer before the expensive model so a bad
    # --base_model (or missing tokenizer files) fails fast.
    tokenizer = AutoTokenizer.from_pretrained(
        args.base_model, trust_remote_code=args.trust_remote_code
    )
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        # Reuse EOS as the pad token when none is defined (presumably needed
        # for batched evaluation in ppl_eval).
        tokenizer.pad_token = tokenizer.eos_token

    model = load_progressive_model(
        args.base_model,
        args.progressive_dir,
        cycle=args.cycle,
        device=args.device,
        dtype=args.dtype,
        trust_remote_code=args.trust_remote_code,
        layer_path=args.layer_path,
    )

    results = ppl_eval.evaluate_ppl_datasets(
        model,
        tokenizer,
        datasets=datasets,
        configs=configs,
        split=args.dataset_split,
        text_field=args.dataset_text_field,
        num_samples=args.num_samples,
        seq_len=args.seq_len,
        batch_size=args.batch_size,
        device=args.device,
        seed=args.seed,
        # Deterministic ordering so runs are comparable across cycles.
        shuffle=False,
        model_family=args.model_family,
        add_bos=args.add_bos,
        max_batches=args.max_batches,
        cache_dir=args.cache_dir,
        num_workers=args.num_workers,
    )

    print("Perplexity results:")
    for name, ppl in results.items():
        print(f"{name}: {ppl:.4f}")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()