Upload folder using huggingface_hub
Browse files- src/imrnns/api.py +2 -2
- src/imrnns/assets.py +16 -11
- src/imrnns/checkpoints.py +2 -4
- src/imrnns/cli.py +7 -7
- src/imrnns/encoders.py +8 -0
src/imrnns/api.py
CHANGED
|
@@ -7,7 +7,7 @@ from .beir_data import load_beir_source
|
|
| 7 |
from .caching import build_cache
|
| 8 |
from .checkpoints import default_checkpoint_name, load_model, save_checkpoint
|
| 9 |
from .data import ContrastiveCachedDataset, load_cached_split
|
| 10 |
-
from .encoders import resolve_encoder_spec
|
| 11 |
from .evaluation import evaluate_model
|
| 12 |
from .model import IMRNN, ModelConfig
|
| 13 |
from .training import TrainingConfig, train_model
|
|
@@ -124,7 +124,7 @@ def train(
|
|
| 124 |
k_values=[k],
|
| 125 |
)
|
| 126 |
|
| 127 |
-
checkpoint_stem = encoder or encoder_spec.key
|
| 128 |
checkpoint_path = output_dir / default_checkpoint_name(checkpoint_stem, dataset)
|
| 129 |
metadata = {
|
| 130 |
"encoder": checkpoint_stem,
|
|
|
|
| 7 |
from .caching import build_cache
|
| 8 |
from .checkpoints import default_checkpoint_name, load_model, save_checkpoint
|
| 9 |
from .data import ContrastiveCachedDataset, load_cached_split
|
| 10 |
+
from .encoders import encoder_storage_key, resolve_encoder_spec
|
| 11 |
from .evaluation import evaluate_model
|
| 12 |
from .model import IMRNN, ModelConfig
|
| 13 |
from .training import TrainingConfig, train_model
|
|
|
|
| 124 |
k_values=[k],
|
| 125 |
)
|
| 126 |
|
| 127 |
+
checkpoint_stem = encoder_storage_key(encoder or encoder_spec.key)
|
| 128 |
checkpoint_path = output_dir / default_checkpoint_name(checkpoint_stem, dataset)
|
| 129 |
metadata = {
|
| 130 |
"encoder": checkpoint_stem,
|
src/imrnns/assets.py
CHANGED
|
@@ -5,7 +5,7 @@ from dataclasses import dataclass
|
|
| 5 |
from pathlib import Path
|
| 6 |
from typing import Optional
|
| 7 |
|
| 8 |
-
from .encoders import normalize_encoder_name
|
| 9 |
|
| 10 |
|
| 11 |
@dataclass(frozen=True)
|
|
@@ -43,10 +43,10 @@ def discover_cached_embeddings(assets_root: Path) -> list[AssetMatch]:
|
|
| 43 |
assets.append(AssetMatch(encoder=encoder, dataset=dataset, path=entry))
|
| 44 |
continue
|
| 45 |
|
| 46 |
-
match = re.fullmatch(r"cache_(
|
| 47 |
if match:
|
| 48 |
encoder, dataset = match.groups()
|
| 49 |
-
encoder =
|
| 50 |
assets.append(AssetMatch(encoder=encoder, dataset=dataset, path=entry))
|
| 51 |
return assets
|
| 52 |
|
|
@@ -68,34 +68,39 @@ def discover_repo_checkpoints(repo_root: Path) -> list[AssetMatch]:
|
|
| 68 |
if not base_dir.exists():
|
| 69 |
return assets
|
| 70 |
for entry in sorted(base_dir.rglob("*.pt")):
|
| 71 |
-
|
| 72 |
-
|
|
|
|
| 73 |
continue
|
| 74 |
-
|
| 75 |
-
if encoder == "minilm":
|
| 76 |
-
encoder = "mini"
|
| 77 |
assets.append(AssetMatch(encoder=encoder, dataset=dataset, path=entry))
|
| 78 |
return assets
|
| 79 |
|
| 80 |
|
| 81 |
def resolve_cache_dir(assets_root: Path, encoder: str, dataset: str) -> Path:
|
| 82 |
-
encoder =
|
| 83 |
dataset = dataset.lower()
|
| 84 |
for asset in discover_cached_embeddings(assets_root):
|
| 85 |
if asset.encoder == encoder and asset.dataset.lower() == dataset:
|
| 86 |
return asset.path
|
|
|
|
|
|
|
|
|
|
| 87 |
raise FileNotFoundError(
|
| 88 |
f"No cached embeddings found for encoder='{encoder}' dataset='{dataset}' under {assets_root}"
|
| 89 |
)
|
| 90 |
|
| 91 |
|
| 92 |
def resolve_checkpoint_path(assets_root: Path, encoder: str, dataset: str) -> Optional[Path]:
|
| 93 |
-
encoder =
|
| 94 |
dataset = dataset.lower()
|
| 95 |
for asset in discover_repo_checkpoints(package_root()):
|
| 96 |
if asset.encoder == encoder and asset.dataset.lower() == dataset:
|
| 97 |
return asset.path
|
| 98 |
for asset in discover_checkpoints(assets_root):
|
| 99 |
-
if asset.encoder == encoder and asset.dataset.lower() == dataset:
|
| 100 |
return asset.path
|
|
|
|
|
|
|
|
|
|
| 101 |
return None
|
|
|
|
| 5 |
from pathlib import Path
|
| 6 |
from typing import Optional
|
| 7 |
|
| 8 |
+
from .encoders import encoder_storage_key, normalize_encoder_name
|
| 9 |
|
| 10 |
|
| 11 |
@dataclass(frozen=True)
|
|
|
|
| 43 |
assets.append(AssetMatch(encoder=encoder, dataset=dataset, path=entry))
|
| 44 |
continue
|
| 45 |
|
| 46 |
+
match = re.fullmatch(r"cache_(.+)_(.+)", entry.name)
|
| 47 |
if match:
|
| 48 |
encoder, dataset = match.groups()
|
| 49 |
+
encoder = encoder_storage_key(encoder)
|
| 50 |
assets.append(AssetMatch(encoder=encoder, dataset=dataset, path=entry))
|
| 51 |
return assets
|
| 52 |
|
|
|
|
| 68 |
if not base_dir.exists():
|
| 69 |
return assets
|
| 70 |
for entry in sorted(base_dir.rglob("*.pt")):
|
| 71 |
+
encoder = encoder_storage_key(entry.parent.name)
|
| 72 |
+
prefix = f"imrnns-{entry.parent.name}-"
|
| 73 |
+
if not entry.name.startswith(prefix) or not entry.name.endswith(".pt"):
|
| 74 |
continue
|
| 75 |
+
dataset = entry.name.removeprefix(prefix).removesuffix(".pt")
|
|
|
|
|
|
|
| 76 |
assets.append(AssetMatch(encoder=encoder, dataset=dataset, path=entry))
|
| 77 |
return assets
|
| 78 |
|
| 79 |
|
| 80 |
def resolve_cache_dir(assets_root: Path, encoder: str, dataset: str) -> Path:
|
| 81 |
+
encoder = encoder_storage_key(encoder)
|
| 82 |
dataset = dataset.lower()
|
| 83 |
for asset in discover_cached_embeddings(assets_root):
|
| 84 |
if asset.encoder == encoder and asset.dataset.lower() == dataset:
|
| 85 |
return asset.path
|
| 86 |
+
direct = assets_root / f"cache_{encoder}_{dataset}"
|
| 87 |
+
if direct.exists():
|
| 88 |
+
return direct
|
| 89 |
raise FileNotFoundError(
|
| 90 |
f"No cached embeddings found for encoder='{encoder}' dataset='{dataset}' under {assets_root}"
|
| 91 |
)
|
| 92 |
|
| 93 |
|
| 94 |
def resolve_checkpoint_path(assets_root: Path, encoder: str, dataset: str) -> Optional[Path]:
|
| 95 |
+
encoder = encoder_storage_key(encoder)
|
| 96 |
dataset = dataset.lower()
|
| 97 |
for asset in discover_repo_checkpoints(package_root()):
|
| 98 |
if asset.encoder == encoder and asset.dataset.lower() == dataset:
|
| 99 |
return asset.path
|
| 100 |
for asset in discover_checkpoints(assets_root):
|
| 101 |
+
if encoder_storage_key(asset.encoder) == encoder and asset.dataset.lower() == dataset:
|
| 102 |
return asset.path
|
| 103 |
+
direct = assets_root / f"imrnns-{encoder}-{dataset}.pt"
|
| 104 |
+
if direct.exists():
|
| 105 |
+
return direct
|
| 106 |
return None
|
src/imrnns/checkpoints.py
CHANGED
|
@@ -6,14 +6,12 @@ from typing import Any
|
|
| 6 |
|
| 7 |
import torch
|
| 8 |
|
| 9 |
-
from .encoders import normalize_encoder_name
|
| 10 |
from .model import IMRNN, ModelConfig
|
| 11 |
|
| 12 |
|
| 13 |
def default_checkpoint_name(encoder: str, dataset: str) -> str:
|
| 14 |
-
|
| 15 |
-
display = "minilm" if normalized == "mini" else normalized
|
| 16 |
-
return f"imrnns-{display}-{dataset}.pt"
|
| 17 |
|
| 18 |
|
| 19 |
def sanitize_legacy_state_dict(state_dict: dict[str, Any]) -> dict[str, Any]:
|
|
|
|
| 6 |
|
| 7 |
import torch
|
| 8 |
|
| 9 |
+
from .encoders import encoder_storage_key, normalize_encoder_name
|
| 10 |
from .model import IMRNN, ModelConfig
|
| 11 |
|
| 12 |
|
| 13 |
def default_checkpoint_name(encoder: str, dataset: str) -> str:
|
| 14 |
+
return f"imrnns-{encoder_storage_key(encoder)}-{dataset}.pt"
|
|
|
|
|
|
|
| 15 |
|
| 16 |
|
| 17 |
def sanitize_legacy_state_dict(state_dict: dict[str, Any]) -> dict[str, Any]:
|
src/imrnns/cli.py
CHANGED
|
@@ -17,7 +17,7 @@ from .beir_data import load_beir_source
|
|
| 17 |
from .caching import build_cache
|
| 18 |
from .checkpoints import default_checkpoint_name, load_model, save_checkpoint
|
| 19 |
from .data import ContrastiveCachedDataset, load_cached_split
|
| 20 |
-
from .encoders import normalize_encoder_name, resolve_encoder_spec
|
| 21 |
from .evaluation import evaluate_model
|
| 22 |
from .model import IMRNN, ModelConfig
|
| 23 |
from .training import TrainingConfig, train_model
|
|
@@ -25,6 +25,7 @@ from .training import TrainingConfig, train_model
|
|
| 25 |
|
| 26 |
def _add_common_args(parser: argparse.ArgumentParser) -> None:
|
| 27 |
parser.add_argument("--assets-root", type=Path, default=default_assets_root())
|
|
|
|
| 28 |
parser.add_argument("--encoder")
|
| 29 |
parser.add_argument("--encoder-model-name")
|
| 30 |
parser.add_argument("--embedding-dim", type=int)
|
|
@@ -46,9 +47,8 @@ def _resolve_encoder_spec(args: argparse.Namespace):
|
|
| 46 |
|
| 47 |
def _encoder_label(args: argparse.Namespace, encoder_spec) -> str:
|
| 48 |
if args.encoder:
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
return encoder_spec.key.replace("/", "-")
|
| 52 |
|
| 53 |
|
| 54 |
def _command_list_assets(args: argparse.Namespace) -> int:
|
|
@@ -76,7 +76,7 @@ def _load_training_inputs(args: argparse.Namespace):
|
|
| 76 |
encoder_spec = _resolve_encoder_spec(args)
|
| 77 |
encoder_label = _encoder_label(args, encoder_spec)
|
| 78 |
cache_dir = args.cache_dir or resolve_cache_dir(args.assets_root, encoder_label, args.dataset)
|
| 79 |
-
datasets_dir = args.assets_root / "datasets"
|
| 80 |
beir_source = load_beir_source(args.dataset, datasets_dir=datasets_dir, max_queries=args.max_queries)
|
| 81 |
train_split = load_cached_split(cache_dir, "train", beir_source, encoder_spec, args.device)
|
| 82 |
val_split = load_cached_split(cache_dir, "val", beir_source, encoder_spec, args.device)
|
|
@@ -96,7 +96,7 @@ def _command_cache(args: argparse.Namespace) -> int:
|
|
| 96 |
dataset_name=args.dataset,
|
| 97 |
encoder_spec=encoder_spec,
|
| 98 |
cache_dir=cache_dir,
|
| 99 |
-
datasets_dir=args.assets_root / "datasets",
|
| 100 |
device=args.device,
|
| 101 |
batch_size=args.batch_size,
|
| 102 |
num_negatives=args.num_negatives,
|
|
@@ -194,7 +194,7 @@ def _command_evaluate(args: argparse.Namespace) -> int:
|
|
| 194 |
f"No checkpoint found for encoder='{encoder_label}' dataset='{args.dataset}'. Provide --checkpoint."
|
| 195 |
)
|
| 196 |
|
| 197 |
-
datasets_dir = args.assets_root / "datasets"
|
| 198 |
beir_source = load_beir_source(args.dataset, datasets_dir=datasets_dir, max_queries=args.max_queries)
|
| 199 |
test_split = load_cached_split(cache_dir, "test", beir_source, encoder_spec, args.device)
|
| 200 |
model, metadata, missing, unexpected = load_model(
|
|
|
|
| 17 |
from .caching import build_cache
|
| 18 |
from .checkpoints import default_checkpoint_name, load_model, save_checkpoint
|
| 19 |
from .data import ContrastiveCachedDataset, load_cached_split
|
| 20 |
+
from .encoders import encoder_storage_key, normalize_encoder_name, resolve_encoder_spec
|
| 21 |
from .evaluation import evaluate_model
|
| 22 |
from .model import IMRNN, ModelConfig
|
| 23 |
from .training import TrainingConfig, train_model
|
|
|
|
| 25 |
|
| 26 |
def _add_common_args(parser: argparse.ArgumentParser) -> None:
|
| 27 |
parser.add_argument("--assets-root", type=Path, default=default_assets_root())
|
| 28 |
+
parser.add_argument("--datasets-dir", type=Path)
|
| 29 |
parser.add_argument("--encoder")
|
| 30 |
parser.add_argument("--encoder-model-name")
|
| 31 |
parser.add_argument("--embedding-dim", type=int)
|
|
|
|
| 47 |
|
| 48 |
def _encoder_label(args: argparse.Namespace, encoder_spec) -> str:
|
| 49 |
if args.encoder:
|
| 50 |
+
return encoder_storage_key(args.encoder)
|
| 51 |
+
return encoder_storage_key(encoder_spec.key)
|
|
|
|
| 52 |
|
| 53 |
|
| 54 |
def _command_list_assets(args: argparse.Namespace) -> int:
|
|
|
|
| 76 |
encoder_spec = _resolve_encoder_spec(args)
|
| 77 |
encoder_label = _encoder_label(args, encoder_spec)
|
| 78 |
cache_dir = args.cache_dir or resolve_cache_dir(args.assets_root, encoder_label, args.dataset)
|
| 79 |
+
datasets_dir = args.datasets_dir or (args.assets_root / "datasets")
|
| 80 |
beir_source = load_beir_source(args.dataset, datasets_dir=datasets_dir, max_queries=args.max_queries)
|
| 81 |
train_split = load_cached_split(cache_dir, "train", beir_source, encoder_spec, args.device)
|
| 82 |
val_split = load_cached_split(cache_dir, "val", beir_source, encoder_spec, args.device)
|
|
|
|
| 96 |
dataset_name=args.dataset,
|
| 97 |
encoder_spec=encoder_spec,
|
| 98 |
cache_dir=cache_dir,
|
| 99 |
+
datasets_dir=args.datasets_dir or (args.assets_root / "datasets"),
|
| 100 |
device=args.device,
|
| 101 |
batch_size=args.batch_size,
|
| 102 |
num_negatives=args.num_negatives,
|
|
|
|
| 194 |
f"No checkpoint found for encoder='{encoder_label}' dataset='{args.dataset}'. Provide --checkpoint."
|
| 195 |
)
|
| 196 |
|
| 197 |
+
datasets_dir = args.datasets_dir or (args.assets_root / "datasets")
|
| 198 |
beir_source = load_beir_source(args.dataset, datasets_dir=datasets_dir, max_queries=args.max_queries)
|
| 199 |
test_split = load_cached_split(cache_dir, "test", beir_source, encoder_spec, args.device)
|
| 200 |
model, metadata, missing, unexpected = load_model(
|
src/imrnns/encoders.py
CHANGED
|
@@ -2,6 +2,7 @@ from __future__ import annotations
|
|
| 2 |
|
| 3 |
from dataclasses import dataclass
|
| 4 |
from typing import Optional
|
|
|
|
| 5 |
|
| 6 |
|
| 7 |
@dataclass(frozen=True)
|
|
@@ -47,6 +48,13 @@ def normalize_encoder_name(name: str) -> str:
|
|
| 47 |
return aliases.get(key, key)
|
| 48 |
|
| 49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
def get_encoder_spec(name: str) -> EncoderSpec:
|
| 51 |
key = normalize_encoder_name(name)
|
| 52 |
if key not in ENCODER_SPECS:
|
|
|
|
| 2 |
|
| 3 |
from dataclasses import dataclass
|
| 4 |
from typing import Optional
|
| 5 |
+
import re
|
| 6 |
|
| 7 |
|
| 8 |
@dataclass(frozen=True)
|
|
|
|
| 48 |
return aliases.get(key, key)
|
| 49 |
|
| 50 |
|
| 51 |
+
def encoder_storage_key(name: str) -> str:
|
| 52 |
+
normalized = normalize_encoder_name(name)
|
| 53 |
+
if normalized == "mini":
|
| 54 |
+
return "minilm"
|
| 55 |
+
return re.sub(r"[^a-z0-9._-]+", "-", normalized.lower()).strip("-")
|
| 56 |
+
|
| 57 |
+
|
| 58 |
def get_encoder_spec(name: str) -> EncoderSpec:
|
| 59 |
key = normalize_encoder_name(name)
|
| 60 |
if key not in ENCODER_SPECS:
|