| from __future__ import annotations | |
| from pathlib import Path | |
| from typing import Any, Optional | |
| from .beir_data import load_beir_source | |
| from .caching import build_cache | |
| from .checkpoints import default_checkpoint_name, load_model, save_checkpoint | |
| from .data import ContrastiveCachedDataset, load_cached_split | |
| from .encoders import encoder_storage_key, resolve_encoder_spec | |
| from .evaluation import evaluate_model | |
| from .model import IMRNN, ModelConfig | |
| from .training import TrainingConfig, train_model | |
def cache_embeddings(
    *,
    encoder: Optional[str],
    dataset: str,
    cache_dir: Path,
    datasets_dir: Path,
    device: str = "cpu",
    encoder_model_name: Optional[str] = None,
    embedding_dim: Optional[int] = None,
    query_prefix: str = "",
    passage_prefix: str = "",
    batch_size: int = 64,
    num_negatives: int = 20,
    negative_pool: int = 200,
    max_queries: Optional[int] = None,
) -> Path:
    """Encode a dataset and persist its embedding cache.

    Resolves the encoder specification from the given identifiers, then
    delegates the actual encoding and serialization to ``build_cache``.

    Returns:
        The path returned by ``build_cache`` for the written cache.
    """
    spec = resolve_encoder_spec(
        encoder=encoder,
        encoder_model_name=encoder_model_name,
        embedding_dim=embedding_dim,
        query_prefix=query_prefix,
        passage_prefix=passage_prefix,
    )
    cache_path = build_cache(
        dataset_name=dataset,
        encoder_spec=spec,
        cache_dir=cache_dir,
        datasets_dir=datasets_dir,
        device=device,
        batch_size=batch_size,
        num_negatives=num_negatives,
        negative_pool=negative_pool,
        max_queries=max_queries,
    )
    return cache_path
def train(
    *,
    encoder: Optional[str],
    dataset: str,
    cache_dir: Path,
    datasets_dir: Path,
    output_dir: Path,
    device: str = "cpu",
    encoder_model_name: Optional[str] = None,
    embedding_dim: Optional[int] = None,
    query_prefix: str = "",
    passage_prefix: str = "",
    max_queries: Optional[int] = None,
    batch_size: int = 32,
    epochs: int = 10,
    lr: float = 1e-4,
    weight_decay: float = 1e-5,
    num_negatives: int = 20,
    output_dim: int = 256,
    hidden_dim: int = 128,
    dropout: float = 0.1,
    feedback_k: int = 100,
    ranking_k: int = 10,
    k: int = 10,
) -> dict[str, Any]:
    """Train an IMRNN on cached embeddings, evaluate it, and checkpoint it.

    Loads the cached train/val/test splits for *dataset*, fits the model,
    scores it on the test split, and writes a checkpoint (with metadata)
    under *output_dir*.

    Returns:
        Dict with the checkpoint path, training metrics, evaluation
        metrics, and the metadata embedded in the checkpoint.

    Raises:
        ValueError: If the cached train or validation split yields no
            contrastive examples.
    """
    spec = resolve_encoder_spec(
        encoder=encoder,
        encoder_model_name=encoder_model_name,
        embedding_dim=embedding_dim,
        query_prefix=query_prefix,
        passage_prefix=passage_prefix,
    )
    source = load_beir_source(dataset, datasets_dir=datasets_dir, max_queries=max_queries)
    splits = {
        name: load_cached_split(cache_dir, name, source, spec, device)
        for name in ("train", "val", "test")
    }

    net = IMRNN(
        ModelConfig(
            input_dim=spec.embedding_dim,
            output_dim=output_dim,
            hidden_dim=hidden_dim,
            dropout=dropout,
        )
    )

    contrastive_train = ContrastiveCachedDataset(splits["train"], num_negatives)
    contrastive_val = ContrastiveCachedDataset(splits["val"], num_negatives)
    # Fail fast on empty datasets rather than letting the training loop
    # surface a less obvious error later.
    if not len(contrastive_train):
        raise ValueError("No training examples were constructed from the cached training split.")
    if not len(contrastive_val):
        raise ValueError("No validation examples were constructed from the cached validation split.")

    training_metrics = train_model(
        model=net,
        train_dataset=contrastive_train,
        val_dataset=contrastive_val,
        config=TrainingConfig(
            batch_size=batch_size,
            epochs=epochs,
            lr=lr,
            weight_decay=weight_decay,
            num_negatives=num_negatives,
        ),
        device=device,
    )
    evaluation_metrics = evaluate_model(
        model=net,
        cached_split=splits["test"],
        device=device,
        feedback_k=feedback_k,
        ranking_k=ranking_k,
        k_values=[k],
    )

    # Prefer the user-supplied encoder name for the checkpoint stem; fall
    # back to the resolved spec's key when none was given.
    stem = encoder_storage_key(encoder or spec.key)
    checkpoint_path = output_dir / default_checkpoint_name(stem, dataset)
    metadata = {
        "encoder": stem,
        "encoder_model_name": spec.model_name,
        "dataset": dataset,
        "cache_dir": str(cache_dir),
        "model_config": {
            "input_dim": spec.embedding_dim,
            "output_dim": output_dim,
            "hidden_dim": hidden_dim,
            "dropout": dropout,
        },
        "training": training_metrics,
        "evaluation": evaluation_metrics,
    }
    save_checkpoint(checkpoint_path, net, metadata)

    return {
        "checkpoint": checkpoint_path,
        "training": training_metrics,
        "evaluation": evaluation_metrics,
        "metadata": metadata,
    }
def evaluate(
    *,
    encoder: Optional[str],
    dataset: str,
    cache_dir: Path,
    datasets_dir: Path,
    checkpoint_path: Path,
    device: str = "cpu",
    encoder_model_name: Optional[str] = None,
    embedding_dim: Optional[int] = None,
    query_prefix: str = "",
    passage_prefix: str = "",
    max_queries: Optional[int] = None,
    output_dim: int = 256,
    hidden_dim: int = 128,
    dropout: float = 0.1,
    feedback_k: int = 100,
    ranking_k: int = 10,
    k: int = 10,
) -> dict[str, Any]:
    """Evaluate a saved IMRNN checkpoint on the cached test split.

    Rebuilds the model from *checkpoint_path* using the architecture
    hyperparameters given here, then scores it against the test split of
    *dataset*.

    Returns:
        Dict with the checkpoint path, evaluation metrics, checkpoint
        metadata, and any missing/unexpected state-dict keys reported by
        ``load_model``.
    """
    spec = resolve_encoder_spec(
        encoder=encoder,
        encoder_model_name=encoder_model_name,
        embedding_dim=embedding_dim,
        query_prefix=query_prefix,
        passage_prefix=passage_prefix,
    )
    source = load_beir_source(dataset, datasets_dir=datasets_dir, max_queries=max_queries)
    cached_test = load_cached_split(cache_dir, "test", source, spec, device)

    # NOTE(review): the architecture here must match what the checkpoint was
    # trained with; mismatches surface via the missing/unexpected key lists.
    net_config = ModelConfig(
        input_dim=spec.embedding_dim,
        output_dim=output_dim,
        hidden_dim=hidden_dim,
        dropout=dropout,
    )
    net, metadata, missing, unexpected = load_model(
        checkpoint_path=checkpoint_path,
        model_config=net_config,
        device=device,
    )

    metrics = evaluate_model(
        model=net,
        cached_split=cached_test,
        device=device,
        feedback_k=feedback_k,
        ranking_k=ranking_k,
        k_values=[k],
    )
    return {
        "checkpoint": checkpoint_path,
        "metrics": metrics,
        "metadata": metadata,
        "missing_keys": missing,
        "unexpected_keys": unexpected,
    }
def run(
    *,
    encoder: Optional[str],
    dataset: str,
    cache_dir: Path,
    datasets_dir: Path,
    output_dir: Path,
    device: str = "cpu",
    encoder_model_name: Optional[str] = None,
    embedding_dim: Optional[int] = None,
    query_prefix: str = "",
    passage_prefix: str = "",
    max_queries: Optional[int] = None,
    batch_size: int = 32,
    epochs: int = 10,
    lr: float = 1e-4,
    weight_decay: float = 1e-5,
    num_negatives: int = 20,
    negative_pool: int = 200,
    output_dim: int = 256,
    hidden_dim: int = 128,
    dropout: float = 0.1,
    feedback_k: int = 100,
    ranking_k: int = 10,
    k: int = 10,
) -> dict[str, Any]:
    """Run the full pipeline: build the embedding cache if needed, then train.

    Returns:
        The result dict produced by ``train`` (checkpoint path, training
        metrics, evaluation metrics, metadata).
    """
    # NOTE(review): this only checks that the cache directory exists — it
    # does not verify the cache matches this encoder/dataset combination;
    # confirm against the cache layout if that matters for callers.
    if not cache_dir.exists():
        cache_embeddings(
            encoder=encoder,
            dataset=dataset,
            cache_dir=cache_dir,
            datasets_dir=datasets_dir,
            device=device,
            encoder_model_name=encoder_model_name,
            embedding_dim=embedding_dim,
            query_prefix=query_prefix,
            passage_prefix=passage_prefix,
            batch_size=batch_size,
            num_negatives=num_negatives,
            negative_pool=negative_pool,
            max_queries=max_queries,
        )

    return train(
        encoder=encoder,
        dataset=dataset,
        cache_dir=cache_dir,
        datasets_dir=datasets_dir,
        output_dir=output_dir,
        device=device,
        encoder_model_name=encoder_model_name,
        embedding_dim=embedding_dim,
        query_prefix=query_prefix,
        passage_prefix=passage_prefix,
        max_queries=max_queries,
        batch_size=batch_size,
        epochs=epochs,
        lr=lr,
        weight_decay=weight_decay,
        num_negatives=num_negatives,
        output_dim=output_dim,
        hidden_dim=hidden_dim,
        dropout=dropout,
        feedback_k=feedback_k,
        ranking_k=ranking_k,
        k=k,
    )