yashsaxena21 committed · verified
Commit 14e3943 · 1 Parent(s): 7a61837

Upload folder using huggingface_hub

src/imrnns/api.py CHANGED
@@ -7,7 +7,7 @@ from .beir_data import load_beir_source
 from .caching import build_cache
 from .checkpoints import default_checkpoint_name, load_model, save_checkpoint
 from .data import ContrastiveCachedDataset, load_cached_split
-from .encoders import resolve_encoder_spec
+from .encoders import encoder_storage_key, resolve_encoder_spec
 from .evaluation import evaluate_model
 from .model import IMRNN, ModelConfig
 from .training import TrainingConfig, train_model
@@ -124,7 +124,7 @@ def train(
         k_values=[k],
     )
 
-    checkpoint_stem = encoder or encoder_spec.key
+    checkpoint_stem = encoder_storage_key(encoder or encoder_spec.key)
     checkpoint_path = output_dir / default_checkpoint_name(checkpoint_stem, dataset)
     metadata = {
         "encoder": checkpoint_stem,
src/imrnns/assets.py CHANGED
@@ -5,7 +5,7 @@ from dataclasses import dataclass
 from pathlib import Path
 from typing import Optional
 
-from .encoders import normalize_encoder_name
+from .encoders import encoder_storage_key, normalize_encoder_name
 
 
 @dataclass(frozen=True)
@@ -43,10 +43,10 @@ def discover_cached_embeddings(assets_root: Path) -> list[AssetMatch]:
             assets.append(AssetMatch(encoder=encoder, dataset=dataset, path=entry))
             continue
 
-        match = re.fullmatch(r"cache_(mini|minilm|e5|mpnet)_(.+)", entry.name)
+        match = re.fullmatch(r"cache_(.+)_(.+)", entry.name)
         if match:
             encoder, dataset = match.groups()
-            encoder = normalize_encoder_name(encoder)
+            encoder = encoder_storage_key(encoder)
             assets.append(AssetMatch(encoder=encoder, dataset=dataset, path=entry))
     return assets
 
@@ -68,34 +68,39 @@ def discover_repo_checkpoints(repo_root: Path) -> list[AssetMatch]:
     if not base_dir.exists():
         return assets
     for entry in sorted(base_dir.rglob("*.pt")):
-        match = re.fullmatch(r"imrnns-(minilm|e5)-(.+)\.pt", entry.name)
-        if not match:
+        encoder = encoder_storage_key(entry.parent.name)
+        prefix = f"imrnns-{entry.parent.name}-"
+        if not entry.name.startswith(prefix) or not entry.name.endswith(".pt"):
             continue
-        encoder, dataset = match.groups()
-        if encoder == "minilm":
-            encoder = "mini"
+        dataset = entry.name.removeprefix(prefix).removesuffix(".pt")
         assets.append(AssetMatch(encoder=encoder, dataset=dataset, path=entry))
     return assets
 
 
 def resolve_cache_dir(assets_root: Path, encoder: str, dataset: str) -> Path:
-    encoder = normalize_encoder_name(encoder)
+    encoder = encoder_storage_key(encoder)
     dataset = dataset.lower()
     for asset in discover_cached_embeddings(assets_root):
         if asset.encoder == encoder and asset.dataset.lower() == dataset:
             return asset.path
+    direct = assets_root / f"cache_{encoder}_{dataset}"
+    if direct.exists():
+        return direct
     raise FileNotFoundError(
         f"No cached embeddings found for encoder='{encoder}' dataset='{dataset}' under {assets_root}"
     )
 
 
 def resolve_checkpoint_path(assets_root: Path, encoder: str, dataset: str) -> Optional[Path]:
-    encoder = normalize_encoder_name(encoder)
+    encoder = encoder_storage_key(encoder)
     dataset = dataset.lower()
     for asset in discover_repo_checkpoints(package_root()):
         if asset.encoder == encoder and asset.dataset.lower() == dataset:
             return asset.path
     for asset in discover_checkpoints(assets_root):
-        if asset.encoder == encoder and asset.dataset.lower() == dataset:
+        if encoder_storage_key(asset.encoder) == encoder and asset.dataset.lower() == dataset:
             return asset.path
+    direct = assets_root / f"imrnns-{encoder}-{dataset}.pt"
+    if direct.exists():
+        return direct
     return None
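
Note: discover_repo_checkpoints no longer hard-codes the (minilm|e5) regex; it derives the encoder from the parent directory name and the dataset from the file name, so new encoders are picked up without code changes. A sketch of the naming scheme it matches (paths are illustrative, not from this diff):

    from pathlib import Path

    # Expected layout: <base_dir>/<encoder>/imrnns-<encoder>-<dataset>.pt
    entry = Path("checkpoints/e5/imrnns-e5-nfcorpus.pt")
    prefix = f"imrnns-{entry.parent.name}-"
    assert entry.name.startswith(prefix) and entry.name.endswith(".pt")
    dataset = entry.name.removeprefix(prefix).removesuffix(".pt")
    print(dataset)  # nfcorpus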
src/imrnns/checkpoints.py CHANGED
@@ -6,14 +6,12 @@ from typing import Any
 
 import torch
 
-from .encoders import normalize_encoder_name
+from .encoders import encoder_storage_key, normalize_encoder_name
 from .model import IMRNN, ModelConfig
 
 
 def default_checkpoint_name(encoder: str, dataset: str) -> str:
-    normalized = normalize_encoder_name(encoder)
-    display = "minilm" if normalized == "mini" else normalized
-    return f"imrnns-{display}-{dataset}.pt"
+    return f"imrnns-{encoder_storage_key(encoder)}-{dataset}.pt"
 
 
 def sanitize_legacy_state_dict(state_dict: dict[str, Any]) -> dict[str, Any]:
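
Note: default_checkpoint_name now delegates the mini-to-minilm rename to encoder_storage_key, so legacy checkpoint names are preserved. For example (dataset name illustrative):

    from imrnns.checkpoints import default_checkpoint_name

    # encoder_storage_key("mini") returns "minilm", keeping the old file name:
    print(default_checkpoint_name("mini", "scifact"))  # imrnns-minilm-scifact.pt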
src/imrnns/cli.py CHANGED
@@ -17,7 +17,7 @@ from .beir_data import load_beir_source
 from .caching import build_cache
 from .checkpoints import default_checkpoint_name, load_model, save_checkpoint
 from .data import ContrastiveCachedDataset, load_cached_split
-from .encoders import normalize_encoder_name, resolve_encoder_spec
+from .encoders import encoder_storage_key, normalize_encoder_name, resolve_encoder_spec
 from .evaluation import evaluate_model
 from .model import IMRNN, ModelConfig
 from .training import TrainingConfig, train_model
@@ -25,6 +25,7 @@ from .training import TrainingConfig, train_model
 
 def _add_common_args(parser: argparse.ArgumentParser) -> None:
     parser.add_argument("--assets-root", type=Path, default=default_assets_root())
+    parser.add_argument("--datasets-dir", type=Path)
     parser.add_argument("--encoder")
     parser.add_argument("--encoder-model-name")
     parser.add_argument("--embedding-dim", type=int)
@@ -46,9 +47,8 @@ def _resolve_encoder_spec(args: argparse.Namespace):
 
 def _encoder_label(args: argparse.Namespace, encoder_spec) -> str:
     if args.encoder:
-        normalized = normalize_encoder_name(args.encoder)
-        return "minilm" if normalized == "mini" else normalized
-    return encoder_spec.key.replace("/", "-")
+        return encoder_storage_key(args.encoder)
+    return encoder_storage_key(encoder_spec.key)
 
 
 def _command_list_assets(args: argparse.Namespace) -> int:
@@ -76,7 +76,7 @@ def _load_training_inputs(args: argparse.Namespace):
     encoder_spec = _resolve_encoder_spec(args)
     encoder_label = _encoder_label(args, encoder_spec)
     cache_dir = args.cache_dir or resolve_cache_dir(args.assets_root, encoder_label, args.dataset)
-    datasets_dir = args.assets_root / "datasets"
+    datasets_dir = args.datasets_dir or (args.assets_root / "datasets")
     beir_source = load_beir_source(args.dataset, datasets_dir=datasets_dir, max_queries=args.max_queries)
     train_split = load_cached_split(cache_dir, "train", beir_source, encoder_spec, args.device)
     val_split = load_cached_split(cache_dir, "val", beir_source, encoder_spec, args.device)
@@ -96,7 +96,7 @@ def _command_cache(args: argparse.Namespace) -> int:
         dataset_name=args.dataset,
         encoder_spec=encoder_spec,
         cache_dir=cache_dir,
-        datasets_dir=args.assets_root / "datasets",
+        datasets_dir=args.datasets_dir or (args.assets_root / "datasets"),
         device=args.device,
         batch_size=args.batch_size,
         num_negatives=args.num_negatives,
@@ -194,7 +194,7 @@ def _command_evaluate(args: argparse.Namespace) -> int:
             f"No checkpoint found for encoder='{encoder_label}' dataset='{args.dataset}'. Provide --checkpoint."
         )
 
-    datasets_dir = args.assets_root / "datasets"
+    datasets_dir = args.datasets_dir or (args.assets_root / "datasets")
     beir_source = load_beir_source(args.dataset, datasets_dir=datasets_dir, max_queries=args.max_queries)
    test_split = load_cached_split(cache_dir, "test", beir_source, encoder_spec, args.device)
     model, metadata, missing, unexpected = load_model(
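
Note: the new --datasets-dir argument is registered once in _add_common_args, and every call site falls back to <assets-root>/datasets when it is omitted. The fallback expression, sketched with a hypothetical Namespace (flag wiring as in this diff, values illustrative):

    from argparse import Namespace
    from pathlib import Path

    args = Namespace(datasets_dir=None, assets_root=Path("assets"))
    # None (flag omitted) falls through to the default location:
    print(args.datasets_dir or (args.assets_root / "datasets"))  # assets/datasets

    args.datasets_dir = Path("/data/beir")
    print(args.datasets_dir or (args.assets_root / "datasets"))  # /data/beir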
src/imrnns/encoders.py CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 from dataclasses import dataclass
 from typing import Optional
+import re
 
 
 @dataclass(frozen=True)
@@ -47,6 +48,13 @@ def normalize_encoder_name(name: str) -> str:
     return aliases.get(key, key)
 
 
+def encoder_storage_key(name: str) -> str:
+    normalized = normalize_encoder_name(name)
+    if normalized == "mini":
+        return "minilm"
+    return re.sub(r"[^a-z0-9._-]+", "-", normalized.lower()).strip("-")
+
+
 def get_encoder_spec(name: str) -> EncoderSpec:
     key = normalize_encoder_name(name)
     if key not in ENCODER_SPECS:
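
Note: encoder_storage_key is now the single source of storage names for caches and checkpoints alike. A few illustrative calls (outputs assume normalize_encoder_name passes unrecognized names through, which this diff does not show):

    from imrnns.encoders import encoder_storage_key

    print(encoder_storage_key("mini"))  # minilm (legacy special case kept)
    print(encoder_storage_key("sentence-transformers/all-MiniLM-L6-v2"))
    # sentence-transformers-all-minilm-l6-v2 -- lowercased, "/" becomes "-"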