File size: 1,446 Bytes
02cf6c1 bd18547 02cf6c1 bd18547 4b4d203 02cf6c1 4b4d203 02cf6c1 4b4d203 02cf6c1 4b4d203 02cf6c1 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 | from __future__ import annotations
import argparse
import json
import sys
from collections.abc import Sequence
from pathlib import Path
# Bootstrap the import path: put <repo>/src at the front of sys.path (only if
# it is not already there) so the ``imrnns`` import that follows can be found
# from a source checkout. NOTE(review): assumes this script lives exactly one
# directory below the repository root and that ``imrnns`` sits under src/.
REPO_ROOT = Path(__file__).resolve().parents[1]
SRC_ROOT = REPO_ROOT / "src"
_SRC_ENTRY = str(SRC_ROOT)
if _SRC_ENTRY not in sys.path:
    sys.path.insert(0, _SRC_ENTRY)
from imrnns import evaluate
def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace:
    """Parse the evaluator's command-line options.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``; passing an explicit
            list lets tests and other scripts reuse this parser without
            touching ``sys.argv``.

    Returns:
        Namespace with the attributes ``checkpoint``, ``encoder``,
        ``dataset``, ``cache_dir``, ``datasets_dir``, ``device``, ``k`` and
        ``feedback_k``. Path-like options are converted to ``Path``.

    Raises:
        SystemExit: If a required option is missing or a value fails to
            parse (argparse prints the usage message first), or when
            ``--help`` is given.
    """
    parser = argparse.ArgumentParser(description="Minimal IMRNN checkpoint evaluator.")
    parser.add_argument(
        "--checkpoint", type=Path, required=True,
        help="Checkpoint file to evaluate.",
    )
    parser.add_argument("--encoder", required=True, help="Encoder name.")
    parser.add_argument("--dataset", required=True, help="Dataset name.")
    parser.add_argument(
        "--cache-dir", type=Path, required=True,
        help="Cache directory used during evaluation.",
    )
    parser.add_argument(
        "--datasets-dir", type=Path, default=Path("datasets"),
        help="Directory containing the datasets (default: %(default)s).",
    )
    parser.add_argument(
        "--device", default="cpu",
        help="Device to run evaluation on (default: %(default)s).",
    )
    parser.add_argument(
        "--k", type=int, default=10,
        help="Value of k used for evaluation (default: %(default)s).",
    )
    parser.add_argument(
        "--feedback-k", type=int, default=100,
        help="Value of feedback_k used for evaluation (default: %(default)s).",
    )
    # argv=None falls back to sys.argv[1:], preserving the original CLI behavior.
    return parser.parse_args(argv)
def main() -> int:
    """Evaluate one checkpoint and print the resulting metrics as JSON.

    Options are read from the command line via ``parse_args`` and forwarded
    by keyword to ``evaluate``. The returned mapping is copied, its
    ``checkpoint`` entry is converted with ``str`` (presumably a ``Path``,
    which ``json`` cannot encode — confirm against ``evaluate``), and the
    result is pretty-printed to stdout.

    Returns:
        Always ``0``; exceptions from parsing or evaluation are not caught.
    """
    cli = parse_args()
    evaluate_kwargs = {
        "checkpoint_path": cli.checkpoint,
        "encoder": cli.encoder,
        "dataset": cli.dataset,
        "cache_dir": cli.cache_dir,
        "datasets_dir": cli.datasets_dir,
        "device": cli.device,
        "k": cli.k,
        "feedback_k": cli.feedback_k,
    }
    report = dict(evaluate(**evaluate_kwargs))
    report["checkpoint"] = str(report["checkpoint"])
    print(json.dumps(report, indent=2))
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())
|