from __future__ import annotations import argparse import json import sys from pathlib import Path REPO_ROOT = Path(__file__).resolve().parents[1] SRC_ROOT = REPO_ROOT / "src" if str(SRC_ROOT) not in sys.path: sys.path.insert(0, str(SRC_ROOT)) from imrnns import evaluate def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Minimal IMRNN checkpoint evaluator.") parser.add_argument("--checkpoint", type=Path, required=True) parser.add_argument("--encoder", required=True) parser.add_argument("--dataset", required=True) parser.add_argument("--cache-dir", type=Path, required=True) parser.add_argument("--datasets-dir", type=Path, default=Path("datasets")) parser.add_argument("--device", default="cpu") parser.add_argument("--k", type=int, default=10) parser.add_argument("--feedback-k", type=int, default=100) return parser.parse_args() def main() -> int: args = parse_args() result = evaluate( encoder=args.encoder, dataset=args.dataset, cache_dir=args.cache_dir, datasets_dir=args.datasets_dir, checkpoint_path=args.checkpoint, device=args.device, feedback_k=args.feedback_k, k=args.k, ) payload = dict(result) payload["checkpoint"] = str(payload["checkpoint"]) print(json.dumps(payload, indent=2)) return 0 if __name__ == "__main__": raise SystemExit(main())