Upload scripts/hf_end_to_end_demo.py with huggingface_hub
Browse files- scripts/hf_end_to_end_demo.py +14 -22
scripts/hf_end_to_end_demo.py
CHANGED
|
@@ -5,8 +5,6 @@ import json
|
|
| 5 |
import sys
|
| 6 |
from pathlib import Path
|
| 7 |
|
| 8 |
-
from huggingface_hub import hf_hub_download
|
| 9 |
-
|
| 10 |
# Allow the demo to import the local IMRNN package directly from this model repo
|
| 11 |
# without requiring a separate editable installation step.
|
| 12 |
REPO_ROOT = Path(__file__).resolve().parents[1]
|
|
@@ -14,13 +12,12 @@ SRC_ROOT = REPO_ROOT / "src"
|
|
| 14 |
if str(SRC_ROOT) not in sys.path:
|
| 15 |
sys.path.insert(0, str(SRC_ROOT))
|
| 16 |
|
| 17 |
-
from imrnns
|
| 18 |
from imrnns.beir_data import load_beir_source
|
| 19 |
-
from imrnns.checkpoints import load_model
|
| 20 |
from imrnns.data import load_cached_split
|
| 21 |
from imrnns.encoders import get_encoder_spec, normalize_encoder_name
|
| 22 |
from imrnns.evaluation import evaluate_model
|
| 23 |
-
from imrnns.
|
| 24 |
|
| 25 |
|
| 26 |
def parse_args() -> argparse.Namespace:
|
|
@@ -60,12 +57,7 @@ def main() -> int:
|
|
| 60 |
# Step 1:
|
| 61 |
# Download the requested IMRNN checkpoint from the public Hugging Face model repo.
|
| 62 |
# By default, the checkpoint path is inferred from the selected encoder and dataset.
|
| 63 |
-
|
| 64 |
-
checkpoint_local_path = hf_hub_download(
|
| 65 |
-
repo_id=args.repo_id,
|
| 66 |
-
filename=checkpoint_repo_path,
|
| 67 |
-
repo_type="model",
|
| 68 |
-
)
|
| 69 |
|
| 70 |
# Step 2:
|
| 71 |
# Choose where the local BEIR cache should live.
|
|
@@ -81,9 +73,9 @@ def main() -> int:
|
|
| 81 |
# If the cache for this encoder/dataset pair does not exist yet, build it from scratch.
|
| 82 |
# This uses the matching base retriever to embed the BEIR dataset locally.
|
| 83 |
if not (cache_dir / "test" / "embeddings.pt").exists():
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
cache_dir=cache_dir,
|
| 88 |
datasets_dir=datasets_dir,
|
| 89 |
device=args.device,
|
|
@@ -100,12 +92,14 @@ def main() -> int:
|
|
| 100 |
cached_test = load_cached_split(cache_dir, "test", source, encoder_spec, args.device)
|
| 101 |
|
| 102 |
# Step 5:
|
| 103 |
-
# Load the IMRNN checkpoint on top of the matching base retriever family.
|
| 104 |
# The checkpoint contains the learned adapter weights used to modulate query and document
|
| 105 |
# embeddings before ranking.
|
| 106 |
-
model, metadata,
|
| 107 |
-
|
| 108 |
-
|
|
|
|
|
|
|
| 109 |
device=args.device,
|
| 110 |
)
|
| 111 |
|
|
@@ -128,15 +122,13 @@ def main() -> int:
|
|
| 128 |
json.dumps(
|
| 129 |
{
|
| 130 |
"repo_id": args.repo_id,
|
| 131 |
-
"checkpoint":
|
| 132 |
-
"local_checkpoint":
|
| 133 |
"encoder": args.encoder,
|
| 134 |
"dataset": args.dataset,
|
| 135 |
"cache_dir": str(cache_dir),
|
| 136 |
"metrics": metrics,
|
| 137 |
"metadata": metadata,
|
| 138 |
-
"missing_keys": missing,
|
| 139 |
-
"unexpected_keys": unexpected,
|
| 140 |
},
|
| 141 |
indent=2,
|
| 142 |
)
|
|
|
|
| 5 |
import sys
|
| 6 |
from pathlib import Path
|
| 7 |
|
|
|
|
|
|
|
| 8 |
# Allow the demo to import the local IMRNN package directly from this model repo
|
| 9 |
# without requiring a separate editable installation step.
|
| 10 |
REPO_ROOT = Path(__file__).resolve().parents[1]
|
|
|
|
| 12 |
if str(SRC_ROOT) not in sys.path:
|
| 13 |
sys.path.insert(0, str(SRC_ROOT))
|
| 14 |
|
| 15 |
+
from imrnns import cache_embeddings
|
| 16 |
from imrnns.beir_data import load_beir_source
|
|
|
|
| 17 |
from imrnns.data import load_cached_split
|
| 18 |
from imrnns.encoders import get_encoder_spec, normalize_encoder_name
|
| 19 |
from imrnns.evaluation import evaluate_model
|
| 20 |
+
from imrnns.hub import load_pretrained
|
| 21 |
|
| 22 |
|
| 23 |
def parse_args() -> argparse.Namespace:
|
|
|
|
| 57 |
# Step 1:
|
| 58 |
# Download the requested IMRNN checkpoint from the public Hugging Face model repo.
|
| 59 |
# By default, the checkpoint path is inferred from the selected encoder and dataset.
|
| 60 |
+
checkpoint_path = args.checkpoint_path or default_hf_checkpoint_path(args.encoder, args.dataset)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
# Step 2:
|
| 63 |
# Choose where the local BEIR cache should live.
|
|
|
|
| 73 |
# If the cache for this encoder/dataset pair does not exist yet, build it from scratch.
|
| 74 |
# This uses the matching base retriever to embed the BEIR dataset locally.
|
| 75 |
if not (cache_dir / "test" / "embeddings.pt").exists():
|
| 76 |
+
cache_embeddings(
|
| 77 |
+
encoder=args.encoder,
|
| 78 |
+
dataset=args.dataset,
|
| 79 |
cache_dir=cache_dir,
|
| 80 |
datasets_dir=datasets_dir,
|
| 81 |
device=args.device,
|
|
|
|
| 92 |
cached_test = load_cached_split(cache_dir, "test", source, encoder_spec, args.device)
|
| 93 |
|
| 94 |
# Step 5:
|
| 95 |
+
# Load the IMRNN checkpoint from the Hugging Face repo on top of the matching base retriever family.
|
| 96 |
# The checkpoint contains the learned adapter weights used to modulate query and document
|
| 97 |
# embeddings before ranking.
|
| 98 |
+
model, metadata, _ = load_pretrained(
|
| 99 |
+
encoder=args.encoder,
|
| 100 |
+
dataset=args.dataset,
|
| 101 |
+
repo_id=args.repo_id,
|
| 102 |
+
checkpoint_filename=checkpoint_path,
|
| 103 |
device=args.device,
|
| 104 |
)
|
| 105 |
|
|
|
|
| 122 |
json.dumps(
|
| 123 |
{
|
| 124 |
"repo_id": args.repo_id,
|
| 125 |
+
"checkpoint": checkpoint_path,
|
| 126 |
+
"local_checkpoint": metadata.get("downloaded_checkpoint"),
|
| 127 |
"encoder": args.encoder,
|
| 128 |
"dataset": args.dataset,
|
| 129 |
"cache_dir": str(cache_dir),
|
| 130 |
"metrics": metrics,
|
| 131 |
"metadata": metadata,
|
|
|
|
|
|
|
| 132 |
},
|
| 133 |
indent=2,
|
| 134 |
)
|