laguna-martini / scripts /eval_loss.py
nikgeo's picture
Publish Laguna Martini grouped-pruning model card and reproducibility artifacts
6f11713 verified
Raw
History Blame Contribute Delete
2.73 kB
#!/usr/bin/env python3
"""Evaluate held-out causal LM loss/perplexity for a token cache."""
from __future__ import annotations
import argparse
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from heapr.eval.loss import evaluate_token_cache
from heapr.model_utils import build_max_memory, load_causal_lm, validate_model_device_placement
from heapr.utils import require_torch
from heapr.utils import write_json
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options for the loss evaluation script.

    Args:
        argv: Argument strings to parse. When ``None`` (the default),
            ``argparse`` falls back to ``sys.argv[1:]``, which matches the
            behavior of calling this function with no arguments.

    Returns:
        The populated ``argparse.Namespace``. ``--model-id`` and
        ``--cache-path`` are required; every other option is optional.
    """
    parser = argparse.ArgumentParser()

    # Required inputs: which model to load and which token cache to score.
    parser.add_argument("--model-id", required=True)
    parser.add_argument("--cache-path", required=True)

    # Optional output location and model revision.
    parser.add_argument("--output-path")
    parser.add_argument("--revision")

    # Evaluation controls.
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--max-chunks", type=int)
    parser.add_argument("--dtype", default="bfloat16")

    # Memory budget and offload settings, forwarded by main() to the
    # memory-map builder and the model loader.
    parser.add_argument("--gpu-memory-per-device")
    parser.add_argument("--max-gpu-memory")
    parser.add_argument("--max-cpu-memory")
    parser.add_argument("--offload-folder")
    parser.add_argument("--allow-cpu-offload", action="store_true")

    # Cache settings; --no-cache is inverted into use_cache by main().
    parser.add_argument("--cache-implementation", default="static")
    parser.add_argument("--no-cache", action="store_true")

    return parser.parse_args(argv)
def main() -> None:
    """Command-line entry point for the held-out loss evaluation.

    Parses the CLI options, builds the memory map via ``build_max_memory``,
    loads the model with ``load_causal_lm``, checks it with
    ``validate_model_device_placement``, runs ``evaluate_token_cache`` on the
    given cache path, and prints the returned metrics.

    Raises:
        ValueError: If ``--offload-folder`` is given without
            ``--allow-cpu-offload``.
    """
    cli = parse_args()
    offload_allowed = cli.allow_cpu_offload

    # An offload folder is only meaningful when CPU offload is enabled.
    if cli.offload_folder and not offload_allowed:
        raise ValueError("--offload-folder requires --allow-cpu-offload")

    memory_limits = build_max_memory(
        gpu_memory_per_device=cli.gpu_memory_per_device,
        max_gpu_memory=cli.max_gpu_memory,
        max_cpu_memory=cli.max_cpu_memory,
        allow_cpu_offload=offload_allowed,
    )

    torch = require_torch()
    # The CUDA device count is only queried when a per-device budget was
    # requested; otherwise no GPU count is passed to the placement check.
    if cli.gpu_memory_per_device:
        expected_gpus = torch.cuda.device_count()
    else:
        expected_gpus = None

    # Shared by the loader and the evaluator so both see the same settings.
    use_cache = not cli.no_cache
    cache_impl = cli.cache_implementation

    model = load_causal_lm(
        cli.model_id,
        revision=cli.revision,
        dtype=cli.dtype,
        max_memory=memory_limits,
        offload_folder=cli.offload_folder if offload_allowed else None,
        use_cache=use_cache,
        cache_implementation=cache_impl,
    )
    validate_model_device_placement(
        model,
        allow_cpu_offload=offload_allowed,
        requested_gpu_count=expected_gpus,
    )

    eval_options = {
        "batch_size": cli.batch_size,
        "max_chunks": cli.max_chunks,
        "use_cache": use_cache,
        "cache_implementation": cache_impl,
        "output_path": cli.output_path,
    }
    metrics = evaluate_token_cache(model, cli.cache_path, **eval_options)

    # Without an explicit output path, the metrics are also written to a fixed
    # fallback location.
    # NOTE(review): the path is constant, so concurrent runs presumably
    # overwrite each other's file - confirm whether that is intended.
    if not cli.output_path:
        write_json("/tmp/heapr_eval_loss.json", metrics)
    print(metrics)
# Run the evaluation only when this file is executed as a script; importing
# the module does not call main().
if __name__ == "__main__":
    main()