yashsaxena21 committed on
Commit
3773231
·
verified ·
1 Parent(s): 4b4d203

Upload scripts/hf_end_to_end_demo.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/hf_end_to_end_demo.py +14 -22
scripts/hf_end_to_end_demo.py CHANGED
@@ -5,8 +5,6 @@ import json
5
  import sys
6
  from pathlib import Path
7
 
8
- from huggingface_hub import hf_hub_download
9
-
10
  # Allow the demo to import the local IMRNN package directly from this model repo
11
  # without requiring a separate editable installation step.
12
  REPO_ROOT = Path(__file__).resolve().parents[1]
@@ -14,13 +12,12 @@ SRC_ROOT = REPO_ROOT / "src"
14
  if str(SRC_ROOT) not in sys.path:
15
  sys.path.insert(0, str(SRC_ROOT))
16
 
17
- from imrnns.caching import build_cache
18
  from imrnns.beir_data import load_beir_source
19
- from imrnns.checkpoints import load_model
20
  from imrnns.data import load_cached_split
21
  from imrnns.encoders import get_encoder_spec, normalize_encoder_name
22
  from imrnns.evaluation import evaluate_model
23
- from imrnns.model import ModelConfig
24
 
25
 
26
  def parse_args() -> argparse.Namespace:
@@ -60,12 +57,7 @@ def main() -> int:
60
  # Step 1:
61
  # Download the requested IMRNN checkpoint from the public Hugging Face model repo.
62
  # By default, the checkpoint path is inferred from the selected encoder and dataset.
63
- checkpoint_repo_path = args.checkpoint_path or default_hf_checkpoint_path(args.encoder, args.dataset)
64
- checkpoint_local_path = hf_hub_download(
65
- repo_id=args.repo_id,
66
- filename=checkpoint_repo_path,
67
- repo_type="model",
68
- )
69
 
70
  # Step 2:
71
  # Choose where the local BEIR cache should live.
@@ -81,9 +73,9 @@ def main() -> int:
81
  # If the cache for this encoder/dataset pair does not exist yet, build it from scratch.
82
  # This uses the matching base retriever to embed the BEIR dataset locally.
83
  if not (cache_dir / "test" / "embeddings.pt").exists():
84
- build_cache(
85
- dataset_name=args.dataset,
86
- encoder_spec=encoder_spec,
87
  cache_dir=cache_dir,
88
  datasets_dir=datasets_dir,
89
  device=args.device,
@@ -100,12 +92,14 @@ def main() -> int:
100
  cached_test = load_cached_split(cache_dir, "test", source, encoder_spec, args.device)
101
 
102
  # Step 5:
103
- # Load the IMRNN checkpoint on top of the matching base retriever family.
104
  # The checkpoint contains the learned adapter weights used to modulate query and document
105
  # embeddings before ranking.
106
- model, metadata, missing, unexpected = load_model(
107
- checkpoint_path=Path(checkpoint_local_path),
108
- model_config=ModelConfig(input_dim=encoder_spec.embedding_dim),
 
 
109
  device=args.device,
110
  )
111
 
@@ -128,15 +122,13 @@ def main() -> int:
128
  json.dumps(
129
  {
130
  "repo_id": args.repo_id,
131
- "checkpoint": checkpoint_repo_path,
132
- "local_checkpoint": checkpoint_local_path,
133
  "encoder": args.encoder,
134
  "dataset": args.dataset,
135
  "cache_dir": str(cache_dir),
136
  "metrics": metrics,
137
  "metadata": metadata,
138
- "missing_keys": missing,
139
- "unexpected_keys": unexpected,
140
  },
141
  indent=2,
142
  )
 
5
  import sys
6
  from pathlib import Path
7
 
 
 
8
  # Allow the demo to import the local IMRNN package directly from this model repo
9
  # without requiring a separate editable installation step.
10
  REPO_ROOT = Path(__file__).resolve().parents[1]
 
12
  if str(SRC_ROOT) not in sys.path:
13
  sys.path.insert(0, str(SRC_ROOT))
14
 
15
+ from imrnns import cache_embeddings
16
  from imrnns.beir_data import load_beir_source
 
17
  from imrnns.data import load_cached_split
18
  from imrnns.encoders import get_encoder_spec, normalize_encoder_name
19
  from imrnns.evaluation import evaluate_model
20
+ from imrnns.hub import load_pretrained
21
 
22
 
23
  def parse_args() -> argparse.Namespace:
 
57
  # Step 1:
58
  # Download the requested IMRNN checkpoint from the public Hugging Face model repo.
59
  # By default, the checkpoint path is inferred from the selected encoder and dataset.
60
+ checkpoint_path = args.checkpoint_path or default_hf_checkpoint_path(args.encoder, args.dataset)
 
 
 
 
 
61
 
62
  # Step 2:
63
  # Choose where the local BEIR cache should live.
 
73
  # If the cache for this encoder/dataset pair does not exist yet, build it from scratch.
74
  # This uses the matching base retriever to embed the BEIR dataset locally.
75
  if not (cache_dir / "test" / "embeddings.pt").exists():
76
+ cache_embeddings(
77
+ encoder=args.encoder,
78
+ dataset=args.dataset,
79
  cache_dir=cache_dir,
80
  datasets_dir=datasets_dir,
81
  device=args.device,
 
92
  cached_test = load_cached_split(cache_dir, "test", source, encoder_spec, args.device)
93
 
94
  # Step 5:
95
+ # Load the IMRNN checkpoint from the Hugging Face repo on top of the matching base retriever family.
96
  # The checkpoint contains the learned adapter weights used to modulate query and document
97
  # embeddings before ranking.
98
+ model, metadata, _ = load_pretrained(
99
+ encoder=args.encoder,
100
+ dataset=args.dataset,
101
+ repo_id=args.repo_id,
102
+ checkpoint_filename=checkpoint_path,
103
  device=args.device,
104
  )
105
 
 
122
  json.dumps(
123
  {
124
  "repo_id": args.repo_id,
125
+ "checkpoint": checkpoint_path,
126
+ "local_checkpoint": metadata.get("downloaded_checkpoint"),
127
  "encoder": args.encoder,
128
  "dataset": args.dataset,
129
  "cache_dir": str(cache_dir),
130
  "metrics": metrics,
131
  "metadata": metadata,
 
 
132
  },
133
  indent=2,
134
  )