# temp_ss/src/ppl_eval_progressive.py — uploaded by LJYAI ("upload src", commit 2c44909)
#!/usr/bin/env python3
"""Evaluate perplexity for a progressive-pruned model assembled from cycles."""
import argparse
import torch
try:
import ppl_eval
except Exception as exc: # pragma: no cover - optional dependency
raise SystemExit("ppl_eval.py is required (missing or invalid)") from exc
try:
from transformers import AutoTokenizer
except Exception as exc: # pragma: no cover - fail early with clear error
raise SystemExit("transformers is required: pip install transformers") from exc
from progressive_loader import load_progressive_model
def parse_args() -> argparse.Namespace:
    """Build and parse the CLI arguments for progressive-pruning PPL evaluation.

    Returns:
        argparse.Namespace with model location, cycle selection, dataset
        options, and evaluation hyper-parameters.
    """
    p = argparse.ArgumentParser(
        description="Evaluate PPL for a model reconstructed from progressive cycles."
    )
    # Model / checkpoint selection.
    p.add_argument("--base_model", required=True, help="Base HF model id or path")
    p.add_argument("--progressive_dir", required=True,
                   help="Output directory from progressive pruning")
    p.add_argument("--cycle", type=int, default=None,
                   help="Cycle to load (default: final)")
    # Dataset selection; --dataset/--dataset_config are repeatable, and the
    # empty-list defaults act as falsy sentinels replaced by wikitext in main().
    p.add_argument("--dataset", action="append", default=[],
                   help="Evaluation dataset name (repeatable). Defaults to wikitext.")
    p.add_argument("--dataset_config", action="append", default=[],
                   help="Evaluation dataset config (repeatable or single shared config).")
    p.add_argument("--dataset_split", default="test",
                   help="Evaluation dataset split (default: test)")
    p.add_argument("--dataset_text_field", default=None,
                   help="Evaluation text field override (default: auto-detect)")
    # Evaluation hyper-parameters.
    p.add_argument("--num_samples", type=int, default=0,
                   help="Number of token sequences per dataset (0 = all)")
    p.add_argument("--seq_len", type=int, default=2048,
                   help="Sequence length for eval")
    p.add_argument("--batch_size", type=int, default=4,
                   help="Batch size for eval")
    p.add_argument("--device",
                   default="cuda" if torch.cuda.is_available() else "cpu",
                   help="Device for eval")
    p.add_argument("--seed", type=int, default=0, help="Random seed")
    # Tokenization / BOS behavior.
    p.add_argument("--model_family", type=str,
                   choices=["auto", "llama", "qwen"], default="auto",
                   help="Model family for BOS handling")
    p.add_argument("--add_bos", type=str,
                   choices=["auto", "always", "never"], default="auto",
                   help="Whether to prepend BOS to each sample")
    # Runtime limits and loading knobs.
    p.add_argument("--max_batches", type=int, default=None,
                   help="Optional max number of eval batches per dataset")
    p.add_argument("--cache_dir", default=None,
                   help="Optional datasets cache dir for eval")
    p.add_argument("--num_workers", type=int, default=0,
                   help="Eval DataLoader workers")
    p.add_argument("--dtype", default="auto",
                   choices=["auto", "float32", "float16", "bfloat16"],
                   help="Model dtype")
    p.add_argument("--trust_remote_code", action="store_true",
                   help="Allow custom model code from hub")
    p.add_argument("--layer_path", default=None,
                   help="Override layer attribute path if needed")
    return p.parse_args()
def main() -> None:
    """Assemble the pruned model for the requested cycle and print its PPL.

    Steps: seed torch, resolve dataset names/configs (defaulting to
    wikitext-2-raw-v1), load the progressive checkpoint and tokenizer, run
    ``ppl_eval.evaluate_ppl_datasets``, and print one perplexity per dataset.
    """
    args = parse_args()
    torch.manual_seed(args.seed)

    # Empty --dataset/--dataset_config lists fall back to wikitext defaults;
    # _expand_dataset_configs aligns a single shared config with many datasets.
    eval_datasets = args.dataset if args.dataset else ["wikitext"]
    eval_configs = args.dataset_config if args.dataset_config else ["wikitext-2-raw-v1"]
    eval_configs = ppl_eval._expand_dataset_configs(eval_datasets, eval_configs)

    model = load_progressive_model(
        args.base_model,
        args.progressive_dir,
        cycle=args.cycle,
        device=args.device,
        dtype=args.dtype,
        trust_remote_code=args.trust_remote_code,
        layer_path=args.layer_path,
    )

    tokenizer = AutoTokenizer.from_pretrained(
        args.base_model, trust_remote_code=args.trust_remote_code
    )
    # Some tokenizers ship without a pad token; reuse EOS so batching works.
    if tokenizer.eos_token is not None and tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    results = ppl_eval.evaluate_ppl_datasets(
        model,
        tokenizer,
        datasets=eval_datasets,
        configs=eval_configs,
        split=args.dataset_split,
        text_field=args.dataset_text_field,
        num_samples=args.num_samples,
        seq_len=args.seq_len,
        batch_size=args.batch_size,
        device=args.device,
        seed=args.seed,
        shuffle=False,
        model_family=args.model_family,
        add_bos=args.add_bos,
        max_batches=args.max_batches,
        cache_dir=args.cache_dir,
        num_workers=args.num_workers,
    )

    print("Perplexity results:")
    for name, ppl in results.items():
        print(f"{name}: {ppl:.4f}")


if __name__ == "__main__":
    main()