#!/usr/bin/env python3
"""Evaluate perplexity for a progressive-pruned model assembled from cycles.

CLI wrapper: reconstructs a model from a progressive-pruning output directory
(optionally at a specific cycle), loads the matching tokenizer, and runs the
shared ``ppl_eval`` perplexity harness over one or more datasets.
"""

import argparse

import torch

# ppl_eval is a sibling project module providing the evaluation harness;
# fail fast with an actionable message rather than a bare ImportError.
try:
    import ppl_eval
except Exception as exc:  # pragma: no cover - optional dependency
    raise SystemExit("ppl_eval.py is required (missing or invalid)") from exc

try:
    from transformers import AutoTokenizer
except Exception as exc:  # pragma: no cover - fail early with clear error
    raise SystemExit("transformers is required: pip install transformers") from exc

from progressive_loader import load_progressive_model


def parse_args() -> argparse.Namespace:
    """Parse command-line arguments for progressive-model PPL evaluation.

    Returns:
        argparse.Namespace with model/loader options, dataset selection
        (repeatable ``--dataset`` / ``--dataset_config``), and eval knobs.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate PPL for a model reconstructed from progressive cycles."
    )
    parser.add_argument("--base_model", required=True, help="Base HF model id or path")
    parser.add_argument(
        "--progressive_dir",
        required=True,
        help="Output directory from progressive pruning",
    )
    parser.add_argument(
        "--cycle",
        type=int,
        default=None,
        help="Cycle to load (default: final)",
    )
    # action="append" lets the flag repeat; default=[] so main() can detect
    # "nothing passed" and substitute the wikitext default.
    parser.add_argument(
        "--dataset",
        action="append",
        default=[],
        help="Evaluation dataset name (repeatable). Defaults to wikitext.",
    )
    parser.add_argument(
        "--dataset_config",
        action="append",
        default=[],
        help="Evaluation dataset config (repeatable or single shared config).",
    )
    parser.add_argument(
        "--dataset_split",
        default="test",
        help="Evaluation dataset split (default: test)",
    )
    parser.add_argument(
        "--dataset_text_field",
        default=None,
        help="Evaluation text field override (default: auto-detect)",
    )
    parser.add_argument(
        "--num_samples",
        type=int,
        default=0,
        help="Number of token sequences per dataset (0 = all)",
    )
    parser.add_argument(
        "--seq_len",
        type=int,
        default=2048,
        help="Sequence length for eval",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=4,
        help="Batch size for eval",
    )
    parser.add_argument(
        "--device",
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device for eval",
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument(
        "--model_family",
        type=str,
        choices=["auto", "llama", "qwen"],
        default="auto",
        help="Model family for BOS handling",
    )
    parser.add_argument(
        "--add_bos",
        type=str,
        choices=["auto", "always", "never"],
        default="auto",
        help="Whether to prepend BOS to each sample",
    )
    parser.add_argument(
        "--max_batches",
        type=int,
        default=None,
        help="Optional max number of eval batches per dataset",
    )
    parser.add_argument(
        "--cache_dir",
        default=None,
        help="Optional datasets cache dir for eval",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=0,
        help="Eval DataLoader workers",
    )
    parser.add_argument(
        "--dtype",
        default="auto",
        choices=["auto", "float32", "float16", "bfloat16"],
        help="Model dtype",
    )
    parser.add_argument(
        "--trust_remote_code",
        action="store_true",
        help="Allow custom model code from hub",
    )
    parser.add_argument(
        "--layer_path",
        default=None,
        help="Override layer attribute path if needed",
    )
    return parser.parse_args()


def main() -> None:
    """Load the reconstructed model and print per-dataset perplexities."""
    args = parse_args()
    torch.manual_seed(args.seed)

    # Empty lists are falsy, so unspecified flags fall back to the
    # wikitext-2 defaults; the helper aligns configs with dataset names.
    datasets = args.dataset or ["wikitext"]
    configs = args.dataset_config or ["wikitext-2-raw-v1"]
    configs = ppl_eval._expand_dataset_configs(datasets, configs)

    model = load_progressive_model(
        args.base_model,
        args.progressive_dir,
        cycle=args.cycle,
        device=args.device,
        dtype=args.dtype,
        trust_remote_code=args.trust_remote_code,
        layer_path=args.layer_path,
    )

    tokenizer = AutoTokenizer.from_pretrained(
        args.base_model, trust_remote_code=args.trust_remote_code
    )
    # Some tokenizers (e.g. llama-family) ship without a pad token; reuse
    # EOS so batched evaluation can pad sequences.
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token

    results = ppl_eval.evaluate_ppl_datasets(
        model,
        tokenizer,
        datasets=datasets,
        configs=configs,
        split=args.dataset_split,
        text_field=args.dataset_text_field,
        num_samples=args.num_samples,
        seq_len=args.seq_len,
        batch_size=args.batch_size,
        device=args.device,
        seed=args.seed,
        shuffle=False,
        model_family=args.model_family,
        add_bos=args.add_bos,
        max_batches=args.max_batches,
        cache_dir=args.cache_dir,
        num_workers=args.num_workers,
    )

    print("Perplexity results:")
    for name, ppl in results.items():
        print(f"{name}: {ppl:.4f}")


if __name__ == "__main__":
    main()