# IMRNNs / scripts/minimal_eval.py
# Uploaded by yashsaxena21 with huggingface_hub (commit 4b4d203, verified).
from __future__ import annotations

import argparse
import json
import os
import sys
from pathlib import Path
# Repository root: the parent of the directory that contains this script.
REPO_ROOT = Path(__file__).resolve().parents[1]
# In-repo source tree. It is prepended to ``sys.path`` (only once) so that the
# ``imrnns`` import that follows resolves from a plain checkout; presumably the
# package lives here rather than being installed.
SRC_ROOT = REPO_ROOT.joinpath("src")
if str(SRC_ROOT) not in sys.path:
    sys.path[:0] = [str(SRC_ROOT)]
from imrnns import evaluate
def _positive_int(text: str) -> int:
    """Convert an option value to an integer that is at least 1.

    Raises:
        argparse.ArgumentTypeError: If *text* is not an integer or is below 1;
            argparse reports this as a usage error and exits.
    """
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected an integer, got {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
    return value


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the evaluator's command-line options.

    Args:
        argv: Argument strings to parse, without the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, which is how
            ``main`` calls this function.

    Returns:
        Namespace with ``checkpoint``, ``cache_dir`` and ``datasets_dir`` as
        ``Path`` objects, ``encoder``, ``dataset`` and ``device`` as strings,
        and ``k`` / ``feedback_k`` as integers >= 1.
    """
    parser = argparse.ArgumentParser(description="Minimal IMRNN checkpoint evaluator.")
    parser.add_argument(
        "--checkpoint", type=Path, required=True,
        help="checkpoint file to evaluate",
    )
    parser.add_argument("--encoder", required=True, help="encoder name passed to evaluate()")
    parser.add_argument("--dataset", required=True, help="dataset name passed to evaluate()")
    parser.add_argument(
        "--cache-dir", type=Path, required=True,
        help="cache directory passed to evaluate()",
    )
    parser.add_argument(
        "--datasets-dir", type=Path, default=Path("datasets"),
        help="directory holding datasets (default: %(default)s)",
    )
    parser.add_argument("--device", default="cpu", help="device string (default: %(default)s)")
    # Reject 0 and negative cutoffs up front instead of letting them flow
    # silently into evaluate().
    parser.add_argument(
        "--k", type=_positive_int, default=10,
        help="cutoff k passed to evaluate() (default: %(default)s)",
    )
    parser.add_argument(
        "--feedback-k", type=_positive_int, default=100,
        help="feedback_k passed to evaluate() (default: %(default)s)",
    )
    return parser.parse_args(argv)
def _json_default(value: object) -> str:
    """Serialize path-like objects for ``json.dumps``; reject anything else.

    Raises:
        TypeError: For any other value ``json`` cannot encode, matching the
            behavior of ``json.dumps`` without a ``default`` hook.
    """
    if isinstance(value, os.PathLike):
        return os.fsdecode(value)
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def main() -> int:
    """Run ``evaluate`` with the parsed CLI options and print its result as JSON.

    Returns:
        Process exit status: 0 on success. Errors propagate as exceptions.
    """
    args = parse_args()
    result = evaluate(
        encoder=args.encoder,
        dataset=args.dataset,
        cache_dir=args.cache_dir,
        datasets_dir=args.datasets_dir,
        checkpoint_path=args.checkpoint,
        device=args.device,
        feedback_k=args.feedback_k,
        k=args.k,
    )
    # Copy into a plain dict so that the object returned by evaluate() is not
    # mutated.
    payload = dict(result)
    # Keep the original string conversion of the checkpoint entry. Tolerate a
    # missing key, and leave None as JSON null instead of the string "None".
    checkpoint = payload.get("checkpoint")
    if checkpoint is not None:
        payload["checkpoint"] = str(checkpoint)
    # Any other Path-valued entries would otherwise make json.dumps raise.
    print(json.dumps(payload, indent=2, default=_json_default))
    return 0
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())