#!/usr/bin/env python3
"""Evaluate held-out causal LM loss/perplexity for a token cache."""

from __future__ import annotations

import argparse
import sys
from pathlib import Path

# Put the parent of this script's directory on the import path so that the
# ``heapr`` package imports below resolve when the file is run directly.
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

from heapr.eval.loss import evaluate_token_cache
from heapr.model_utils import build_max_memory, load_causal_lm, validate_model_device_placement
from heapr.utils import require_torch
from heapr.utils import write_json


def parse_args():
    """Build the command-line parser and return the parsed options.

    Only ``--model-id`` and ``--cache-path`` are mandatory. Options are
    registered in the order listed in the table below, which is also the
    order argparse uses in its help output.
    """
    option_table = (
        ("--model-id", {"required": True}),
        ("--cache-path", {"required": True}),
        ("--output-path", {}),
        ("--revision", {}),
        ("--batch-size", {"type": int, "default": 1}),
        ("--max-chunks", {"type": int}),
        ("--dtype", {"default": "bfloat16"}),
        ("--gpu-memory-per-device", {}),
        ("--max-gpu-memory", {}),
        ("--max-cpu-memory", {}),
        ("--offload-folder", {}),
        ("--allow-cpu-offload", {"action": "store_true"}),
        ("--cache-implementation", {"default": "static"}),
        ("--no-cache", {"action": "store_true"}),
    )

    cli = argparse.ArgumentParser()
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    return cli.parse_args()


def main() -> None:
    """Load the model, evaluate it on the token cache, and report metrics.

    Steps, in order: parse the CLI options, build the memory budget, load the
    model, validate its device placement, run the evaluation, then print the
    resulting metrics. When no ``--output-path`` is given, the metrics are
    additionally written to ``/tmp/heapr_eval_loss.json``.

    Raises:
        ValueError: if ``--offload-folder`` is supplied without
            ``--allow-cpu-offload``.
    """
    opts = parse_args()

    # An offload folder is meaningless unless CPU offload was explicitly enabled.
    if opts.offload_folder and not opts.allow_cpu_offload:
        raise ValueError("--offload-folder requires --allow-cpu-offload")

    memory_budget = build_max_memory(
        gpu_memory_per_device=opts.gpu_memory_per_device,
        max_gpu_memory=opts.max_gpu_memory,
        max_cpu_memory=opts.max_cpu_memory,
        allow_cpu_offload=opts.allow_cpu_offload,
    )

    torch = require_torch()

    # The visible GPU count is only passed on for validation when a per-device
    # memory limit was requested; otherwise no count is given.
    if opts.gpu_memory_per_device:
        expected_gpus = torch.cuda.device_count()
    else:
        expected_gpus = None

    # Only forward the offload folder when CPU offload is enabled.
    offload_dir = None
    if opts.allow_cpu_offload:
        offload_dir = opts.offload_folder

    # The same cache setting is used for model loading and for evaluation.
    cache_enabled = not opts.no_cache

    model = load_causal_lm(
        opts.model_id,
        revision=opts.revision,
        dtype=opts.dtype,
        max_memory=memory_budget,
        offload_folder=offload_dir,
        use_cache=cache_enabled,
        cache_implementation=opts.cache_implementation,
    )
    validate_model_device_placement(
        model,
        allow_cpu_offload=opts.allow_cpu_offload,
        requested_gpu_count=expected_gpus,
    )

    # NOTE(review): output_path is forwarded to evaluate_token_cache, which
    # presumably persists the metrics there when it is set -- confirm.
    metrics = evaluate_token_cache(
        model,
        opts.cache_path,
        batch_size=opts.batch_size,
        max_chunks=opts.max_chunks,
        use_cache=cache_enabled,
        cache_implementation=opts.cache_implementation,
        output_path=opts.output_path,
    )

    # Without an explicit output path, fall back to a fixed default location.
    if not opts.output_path:
        write_json("/tmp/heapr_eval_loss.json", metrics)

    print(metrics)


if __name__ == "__main__":
    main()