"""End-to-end IMRNN demo.

Downloads a pretrained IMRNN checkpoint from the Hugging Face Hub, builds the
embedding cache if it is missing, and evaluates on a BEIR dataset. A typical
invocation (`scifact` is just an example BEIR dataset name):

    python <path to this script> --encoder minilm --dataset scifact
"""

from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path

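# Make the in-repo `imrnns` package importable when running from a source checkout.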
REPO_ROOT = Path(__file__).resolve().parents[1]
SRC_ROOT = REPO_ROOT / "src"
if str(SRC_ROOT) not in sys.path:
    sys.path.insert(0, str(SRC_ROOT))

from imrnns import cache_embeddings  # noqa: E402
from imrnns.beir_data import load_beir_source  # noqa: E402
from imrnns.data import load_cached_split  # noqa: E402
from imrnns.encoders import get_encoder_spec, normalize_encoder_name  # noqa: E402
from imrnns.evaluation import evaluate_model  # noqa: E402
from imrnns.hub import load_pretrained  # noqa: E402

def parse_args() -> argparse.Namespace:
    """Parse command-line options for the demo."""
    parser = argparse.ArgumentParser(
        description=(
            "End-to-end IMRNN demo: download a checkpoint from Hugging Face, "
            "build the cache if needed, and evaluate."
        )
    )
    parser.add_argument("--repo-id", default="yashsaxena21/IMRNNs")
    parser.add_argument("--encoder", required=True, help="minilm or e5")
    parser.add_argument("--dataset", required=True, help="BEIR dataset name")
    parser.add_argument("--checkpoint-path", help="Optional path inside the HF repo")
    parser.add_argument("--device", default="cpu")
    parser.add_argument("--k", type=int, default=10)
    parser.add_argument("--feedback-k", type=int, default=100)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--num-negatives", type=int, default=20)
    parser.add_argument("--negative-pool", type=int, default=200)
    parser.add_argument("--max-queries", type=int, default=None)
    parser.add_argument("--cache-dir", type=Path, default=None)
    parser.add_argument("--datasets-dir", type=Path, default=Path("datasets"))
    return parser.parse_args()


def default_hf_checkpoint_path(encoder: str, dataset: str) -> str:
    """Return the conventional checkpoint path inside the Hugging Face repo."""
    normalized = normalize_encoder_name(encoder)
    display = "minilm" if normalized == "mini" else normalized
    return f"checkpoints/pretrained/{display}/imrnns-{display}-{dataset}.pt"


def main() -> int:
    args = parse_args()

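    # Resolve the encoder spec; the internal key "mini" maps to the public
    # name "minilm" used in cache paths and checkpoint filenames.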
    encoder_spec = get_encoder_spec(args.encoder)
    normalized_encoder = "minilm" if encoder_spec.key == "mini" else encoder_spec.key

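    # Use the conventional in-repo checkpoint path unless one was given explicitly.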
    checkpoint_path = args.checkpoint_path or default_hf_checkpoint_path(args.encoder, args.dataset)

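    # Default to a per-encoder, per-dataset cache directory under demo_cache/.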
    cache_dir = args.cache_dir or Path("demo_cache") / f"cache_{normalized_encoder}_{args.dataset}"
    datasets_dir = args.datasets_dir

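    # Build the embedding cache only when the test split has not been cached
    # yet; later runs reuse the cached tensors instead of re-encoding.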
    if not (cache_dir / "test" / "embeddings.pt").exists():
        cache_embeddings(
            encoder=args.encoder,
            dataset=args.dataset,
            cache_dir=cache_dir,
            datasets_dir=datasets_dir,
            device=args.device,
            batch_size=args.batch_size,
            num_negatives=args.num_negatives,
            negative_pool=args.negative_pool,
            max_queries=args.max_queries,
        )

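    # Load the BEIR queries/qrels and the cached test-split embeddings.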
    source = load_beir_source(args.dataset, datasets_dir=datasets_dir, max_queries=args.max_queries)
    cached_test = load_cached_split(cache_dir, "test", source, encoder_spec, args.device)

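    # Fetch the pretrained checkpoint from the Hugging Face Hub and load the model.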
    model, metadata, _ = load_pretrained(
        encoder=args.encoder,
        dataset=args.dataset,
        repo_id=args.repo_id,
        checkpoint_filename=checkpoint_path,
        device=args.device,
    )

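    # Evaluate on the cached test split; metrics are reported at cutoff --k.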
    metrics = evaluate_model(
        model=model,
        cached_split=cached_test,
        device=args.device,
        feedback_k=args.feedback_k,
        ranking_k=args.k,
        k_values=[args.k],
    )

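    # Print a machine-readable JSON summary of the run and its metrics.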
    print(
        json.dumps(
            {
                "repo_id": args.repo_id,
                "checkpoint": checkpoint_path,
                "local_checkpoint": metadata.get("downloaded_checkpoint"),
                "encoder": args.encoder,
                "dataset": args.dataset,
                "cache_dir": str(cache_dir),
                "metrics": metrics,
                "metadata": metadata,
            },
            indent=2,
        )
    )

    return 0


if __name__ == "__main__":
    raise SystemExit(main())