# IMRNNs / src/imrnns/api.py
# Uploaded by yashsaxena21 via huggingface_hub (commit 14e3943, verified).
from __future__ import annotations
from pathlib import Path
from typing import Any, Optional
from .beir_data import load_beir_source
from .caching import build_cache
from .checkpoints import default_checkpoint_name, load_model, save_checkpoint
from .data import ContrastiveCachedDataset, load_cached_split
from .encoders import encoder_storage_key, resolve_encoder_spec
from .evaluation import evaluate_model
from .model import IMRNN, ModelConfig
from .training import TrainingConfig, train_model
def cache_embeddings(
    *,
    encoder: Optional[str],
    dataset: str,
    cache_dir: Path,
    datasets_dir: Path,
    device: str = "cpu",
    encoder_model_name: Optional[str] = None,
    embedding_dim: Optional[int] = None,
    query_prefix: str = "",
    passage_prefix: str = "",
    batch_size: int = 64,
    num_negatives: int = 20,
    negative_pool: int = 200,
    max_queries: Optional[int] = None,
) -> Path:
    """Encode a BEIR dataset with the selected encoder and persist the cache.

    Resolves the encoder specification from the given identifiers/overrides,
    then delegates to ``build_cache``. Returns the path produced by
    ``build_cache``.
    """
    spec = resolve_encoder_spec(
        encoder=encoder,
        encoder_model_name=encoder_model_name,
        embedding_dim=embedding_dim,
        query_prefix=query_prefix,
        passage_prefix=passage_prefix,
    )
    cache_kwargs = dict(
        dataset_name=dataset,
        encoder_spec=spec,
        cache_dir=cache_dir,
        datasets_dir=datasets_dir,
        device=device,
        batch_size=batch_size,
        num_negatives=num_negatives,
        negative_pool=negative_pool,
        max_queries=max_queries,
    )
    return build_cache(**cache_kwargs)
def train(
    *,
    encoder: Optional[str],
    dataset: str,
    cache_dir: Path,
    datasets_dir: Path,
    output_dir: Path,
    device: str = "cpu",
    encoder_model_name: Optional[str] = None,
    embedding_dim: Optional[int] = None,
    query_prefix: str = "",
    passage_prefix: str = "",
    max_queries: Optional[int] = None,
    batch_size: int = 32,
    epochs: int = 10,
    lr: float = 1e-4,
    weight_decay: float = 1e-5,
    num_negatives: int = 20,
    output_dim: int = 256,
    hidden_dim: int = 128,
    dropout: float = 0.1,
    feedback_k: int = 100,
    ranking_k: int = 10,
    k: int = 10,
) -> dict[str, Any]:
    """Train an IMRNN on cached embeddings, evaluate it, and write a checkpoint.

    Loads the train/val/test splits from ``cache_dir`` (these must have been
    produced beforehand, e.g. via :func:`cache_embeddings`, with a matching
    encoder spec), trains the model, evaluates on the test split, and saves a
    checkpoint plus metadata under ``output_dir``.

    Returns a dict with keys ``checkpoint``, ``training``, ``evaluation``,
    and ``metadata``.

    Raises:
        ValueError: if the cached train or validation split yields no
            contrastive examples.
    """
    encoder_spec = resolve_encoder_spec(
        encoder=encoder,
        encoder_model_name=encoder_model_name,
        embedding_dim=embedding_dim,
        query_prefix=query_prefix,
        passage_prefix=passage_prefix,
    )
    beir_source = load_beir_source(dataset, datasets_dir=datasets_dir, max_queries=max_queries)
    # All three splits must come from the same cache / encoder combination.
    train_split = load_cached_split(cache_dir, "train", beir_source, encoder_spec, device)
    val_split = load_cached_split(cache_dir, "val", beir_source, encoder_spec, device)
    test_split = load_cached_split(cache_dir, "test", beir_source, encoder_spec, device)
    model = IMRNN(
        ModelConfig(
            input_dim=encoder_spec.embedding_dim,
            output_dim=output_dim,
            hidden_dim=hidden_dim,
            dropout=dropout,
        )
    )
    train_dataset = ContrastiveCachedDataset(train_split, num_negatives)
    val_dataset = ContrastiveCachedDataset(val_split, num_negatives)
    # Fail fast with a clear message instead of letting an empty dataset
    # surface as an obscure error inside the training loop.
    if len(train_dataset) == 0:
        raise ValueError("No training examples were constructed from the cached training split.")
    if len(val_dataset) == 0:
        raise ValueError("No validation examples were constructed from the cached validation split.")
    training_metrics = train_model(
        model=model,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        config=TrainingConfig(
            batch_size=batch_size,
            epochs=epochs,
            lr=lr,
            weight_decay=weight_decay,
            num_negatives=num_negatives,
        ),
        device=device,
    )
    evaluation_metrics = evaluate_model(
        model=model,
        cached_split=test_split,
        device=device,
        feedback_k=feedback_k,
        ranking_k=ranking_k,
        k_values=[k],
    )
    checkpoint_stem = encoder_storage_key(encoder or encoder_spec.key)
    # Fix: ensure the output directory exists before writing the checkpoint;
    # a fresh output_dir would otherwise raise FileNotFoundError (harmless if
    # save_checkpoint already creates it).
    output_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_path = output_dir / default_checkpoint_name(checkpoint_stem, dataset)
    # Metadata is persisted alongside the weights so `evaluate` runs can
    # reconstruct the model configuration later.
    metadata = {
        "encoder": checkpoint_stem,
        "encoder_model_name": encoder_spec.model_name,
        "dataset": dataset,
        "cache_dir": str(cache_dir),
        "model_config": {
            "input_dim": encoder_spec.embedding_dim,
            "output_dim": output_dim,
            "hidden_dim": hidden_dim,
            "dropout": dropout,
        },
        "training": training_metrics,
        "evaluation": evaluation_metrics,
    }
    save_checkpoint(checkpoint_path, model, metadata)
    return {
        "checkpoint": checkpoint_path,
        "training": training_metrics,
        "evaluation": evaluation_metrics,
        "metadata": metadata,
    }
def evaluate(
    *,
    encoder: Optional[str],
    dataset: str,
    cache_dir: Path,
    datasets_dir: Path,
    checkpoint_path: Path,
    device: str = "cpu",
    encoder_model_name: Optional[str] = None,
    embedding_dim: Optional[int] = None,
    query_prefix: str = "",
    passage_prefix: str = "",
    max_queries: Optional[int] = None,
    output_dim: int = 256,
    hidden_dim: int = 128,
    dropout: float = 0.1,
    feedback_k: int = 100,
    ranking_k: int = 10,
    k: int = 10,
) -> dict[str, Any]:
    """Evaluate a saved IMRNN checkpoint on the cached test split.

    Resolves the encoder spec, loads the cached test split and the checkpoint,
    and runs ``evaluate_model``. Returns a dict with the checkpoint path, the
    computed metrics, the checkpoint metadata, and any missing/unexpected
    state-dict keys reported while loading.
    """
    spec = resolve_encoder_spec(
        encoder=encoder,
        encoder_model_name=encoder_model_name,
        embedding_dim=embedding_dim,
        query_prefix=query_prefix,
        passage_prefix=passage_prefix,
    )
    source = load_beir_source(dataset, datasets_dir=datasets_dir, max_queries=max_queries)
    cached_test = load_cached_split(cache_dir, "test", source, spec, device)
    model_config = ModelConfig(
        input_dim=spec.embedding_dim,
        output_dim=output_dim,
        hidden_dim=hidden_dim,
        dropout=dropout,
    )
    model, metadata, missing, unexpected = load_model(
        checkpoint_path=checkpoint_path,
        model_config=model_config,
        device=device,
    )
    metrics = evaluate_model(
        model=model,
        cached_split=cached_test,
        device=device,
        feedback_k=feedback_k,
        ranking_k=ranking_k,
        k_values=[k],
    )
    result: dict[str, Any] = {
        "checkpoint": checkpoint_path,
        "metrics": metrics,
        "metadata": metadata,
        "missing_keys": missing,
        "unexpected_keys": unexpected,
    }
    return result
def run(
    *,
    encoder: Optional[str],
    dataset: str,
    cache_dir: Path,
    datasets_dir: Path,
    output_dir: Path,
    device: str = "cpu",
    encoder_model_name: Optional[str] = None,
    embedding_dim: Optional[int] = None,
    query_prefix: str = "",
    passage_prefix: str = "",
    max_queries: Optional[int] = None,
    batch_size: int = 32,
    epochs: int = 10,
    lr: float = 1e-4,
    weight_decay: float = 1e-5,
    num_negatives: int = 20,
    negative_pool: int = 200,
    output_dim: int = 256,
    hidden_dim: int = 128,
    dropout: float = 0.1,
    feedback_k: int = 100,
    ranking_k: int = 10,
    k: int = 10,
) -> dict[str, Any]:
    """End-to-end pipeline: build the embedding cache if absent, then train.

    NOTE(review): the cache is only (re)built when ``cache_dir`` is missing
    entirely; a present-but-stale or partial cache is used as-is — confirm
    this matches the cache layout's expectations.

    Returns whatever :func:`train` returns.
    """
    shared = dict(
        encoder=encoder,
        dataset=dataset,
        cache_dir=cache_dir,
        datasets_dir=datasets_dir,
        device=device,
        encoder_model_name=encoder_model_name,
        embedding_dim=embedding_dim,
        query_prefix=query_prefix,
        passage_prefix=passage_prefix,
        max_queries=max_queries,
        batch_size=batch_size,
        num_negatives=num_negatives,
    )
    if not cache_dir.exists():
        cache_embeddings(negative_pool=negative_pool, **shared)
    return train(
        output_dir=output_dir,
        epochs=epochs,
        lr=lr,
        weight_decay=weight_decay,
        output_dim=output_dim,
        hidden_dim=hidden_dim,
        dropout=dropout,
        feedback_k=feedback_k,
        ranking_k=ranking_k,
        k=k,
        **shared,
    )