| |
| """Evaluate held-out causal LM loss/perplexity for a token cache.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import sys |
| from pathlib import Path |
|
|
| sys.path.insert(0, str(Path(__file__).resolve().parents[1])) |
|
|
| from heapr.eval.loss import evaluate_token_cache |
| from heapr.model_utils import build_max_memory, load_causal_lm, validate_model_device_placement |
| from heapr.utils import require_torch |
| from heapr.utils import write_json |
|
|
|
|
| def parse_args(): |
| parser = argparse.ArgumentParser() |
| parser.add_argument("--model-id", required=True) |
| parser.add_argument("--cache-path", required=True) |
| parser.add_argument("--output-path") |
| parser.add_argument("--revision") |
| parser.add_argument("--batch-size", type=int, default=1) |
| parser.add_argument("--max-chunks", type=int) |
| parser.add_argument("--dtype", default="bfloat16") |
| parser.add_argument("--gpu-memory-per-device") |
| parser.add_argument("--max-gpu-memory") |
| parser.add_argument("--max-cpu-memory") |
| parser.add_argument("--offload-folder") |
| parser.add_argument("--allow-cpu-offload", action="store_true") |
| parser.add_argument("--cache-implementation", default="static") |
| parser.add_argument("--no-cache", action="store_true") |
| return parser.parse_args() |
|
|
|
|
def main() -> None:
    """Run the evaluation described by the command line and print the metrics.

    Steps, in order: parse the arguments, build the ``max_memory`` setting,
    load the causal LM, validate its device placement, evaluate the token
    cache, and print the resulting metrics. When ``--output-path`` is not
    given, the metrics are also written to ``/tmp/heapr_eval_loss.json``.

    Raises:
        ValueError: If ``--offload-folder`` is given without
            ``--allow-cpu-offload``.
    """
    opts = parse_args()
    cpu_offload = opts.allow_cpu_offload

    # Reject the inconsistent flag combination before any heavy work starts.
    if opts.offload_folder and not cpu_offload:
        raise ValueError("--offload-folder requires --allow-cpu-offload")

    memory_limits = build_max_memory(
        gpu_memory_per_device=opts.gpu_memory_per_device,
        max_gpu_memory=opts.max_gpu_memory,
        max_cpu_memory=opts.max_cpu_memory,
        allow_cpu_offload=cpu_offload,
    )

    torch = require_torch()
    # The visible GPU count is only requested when a per-device limit is set.
    if opts.gpu_memory_per_device:
        gpu_count = torch.cuda.device_count()
    else:
        gpu_count = None

    # The same two cache settings go to both the loader and the evaluator.
    cache_enabled = not opts.no_cache
    cache_impl = opts.cache_implementation

    lm = load_causal_lm(
        opts.model_id,
        revision=opts.revision,
        dtype=opts.dtype,
        max_memory=memory_limits,
        offload_folder=opts.offload_folder if cpu_offload else None,
        use_cache=cache_enabled,
        cache_implementation=cache_impl,
    )
    validate_model_device_placement(
        lm,
        allow_cpu_offload=cpu_offload,
        requested_gpu_count=gpu_count,
    )

    results = evaluate_token_cache(
        lm,
        opts.cache_path,
        batch_size=opts.batch_size,
        max_chunks=opts.max_chunks,
        use_cache=cache_enabled,
        cache_implementation=cache_impl,
        output_path=opts.output_path,
    )

    # Fallback location used when the caller did not name an output file.
    if not opts.output_path:
        write_json("/tmp/heapr_eval_loss.json", results)
    print(results)
|
|
|
|
# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|