from __future__ import annotations

from pathlib import Path
from typing import Any, Optional

from .beir_data import load_beir_source
from .caching import build_cache
from .checkpoints import default_checkpoint_name, load_model, save_checkpoint
from .data import ContrastiveCachedDataset, load_cached_split
from .encoders import encoder_storage_key, resolve_encoder_spec
from .evaluation import evaluate_model
from .model import IMRNN, ModelConfig
from .training import TrainingConfig, train_model


def cache_embeddings(
    *,
    encoder: Optional[str],
    dataset: str,
    cache_dir: Path,
    datasets_dir: Path,
    device: str = "cpu",
    encoder_model_name: Optional[str] = None,
    embedding_dim: Optional[int] = None,
    query_prefix: str = "",
    passage_prefix: str = "",
    batch_size: int = 64,
    num_negatives: int = 20,
    negative_pool: int = 200,
    max_queries: Optional[int] = None,
) -> Path:
    """Resolve the encoder and build the embedding cache for ``dataset``.

    The encoder-related arguments (``encoder``, ``encoder_model_name``,
    ``embedding_dim``, ``query_prefix``, ``passage_prefix``) are passed to
    ``resolve_encoder_spec``. Everything else is forwarded unchanged to
    ``build_cache``.

    Returns:
        The path returned by ``build_cache``. Presumably this is the
        directory holding this dataset/encoder's cache; confirm against
        ``build_cache``.
    """
    encoder_spec = resolve_encoder_spec(
        encoder=encoder,
        encoder_model_name=encoder_model_name,
        embedding_dim=embedding_dim,
        query_prefix=query_prefix,
        passage_prefix=passage_prefix,
    )
    return build_cache(
        dataset_name=dataset,
        encoder_spec=encoder_spec,
        cache_dir=cache_dir,
        datasets_dir=datasets_dir,
        device=device,
        batch_size=batch_size,
        num_negatives=num_negatives,
        negative_pool=negative_pool,
        max_queries=max_queries,
    )


def train(
    *,
    encoder: Optional[str],
    dataset: str,
    cache_dir: Path,
    datasets_dir: Path,
    output_dir: Path,
    device: str = "cpu",
    encoder_model_name: Optional[str] = None,
    embedding_dim: Optional[int] = None,
    query_prefix: str = "",
    passage_prefix: str = "",
    max_queries: Optional[int] = None,
    batch_size: int = 32,
    epochs: int = 10,
    lr: float = 1e-4,
    weight_decay: float = 1e-5,
    num_negatives: int = 20,
    output_dim: int = 256,
    hidden_dim: int = 128,
    dropout: float = 0.1,
    feedback_k: int = 100,
    ranking_k: int = 10,
    k: int = 10,
) -> dict[str, Any]:
    """Train an IMRNN on cached embeddings, evaluate it, and save a checkpoint.

    Steps:

    1. Load the cached ``train``/``val``/``test`` splits.
    2. Build an ``IMRNN`` from the model hyper-parameters.
    3. Train it with ``train_model`` on contrastive datasets built from
       the train and val splits.
    4. Evaluate on the test split with ``evaluate_model``, using
       ``k_values=[k]``.
    5. Save the model and its metadata to
       ``output_dir / default_checkpoint_name(<encoder key>, dataset)``.

    Returns:
        A dict with keys ``"checkpoint"`` (the checkpoint ``Path``),
        ``"training"``, ``"evaluation"`` and ``"metadata"``.

    Raises:
        ValueError: if no training or no validation examples could be
            constructed from the cached splits.
    """
    encoder_spec = resolve_encoder_spec(
        encoder=encoder,
        encoder_model_name=encoder_model_name,
        embedding_dim=embedding_dim,
        query_prefix=query_prefix,
        passage_prefix=passage_prefix,
    )
    beir_source = load_beir_source(dataset, datasets_dir=datasets_dir, max_queries=max_queries)
    train_split = load_cached_split(cache_dir, "train", beir_source, encoder_spec, device)
    val_split = load_cached_split(cache_dir, "val", beir_source, encoder_spec, device)
    test_split = load_cached_split(cache_dir, "test", beir_source, encoder_spec, device)

    model = IMRNN(
        ModelConfig(
            input_dim=encoder_spec.embedding_dim,
            output_dim=output_dim,
            hidden_dim=hidden_dim,
            dropout=dropout,
        )
    )
    train_dataset = ContrastiveCachedDataset(train_split, num_negatives)
    val_dataset = ContrastiveCachedDataset(val_split, num_negatives)
    if len(train_dataset) == 0:
        raise ValueError("No training examples were constructed from the cached training split.")
    if len(val_dataset) == 0:
        raise ValueError("No validation examples were constructed from the cached validation split.")

    training_metrics = train_model(
        model=model,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        config=TrainingConfig(
            batch_size=batch_size,
            epochs=epochs,
            lr=lr,
            weight_decay=weight_decay,
            num_negatives=num_negatives,
        ),
        device=device,
    )
    evaluation_metrics = evaluate_model(
        model=model,
        cached_split=test_split,
        device=device,
        feedback_k=feedback_k,
        ranking_k=ranking_k,
        k_values=[k],
    )

    # Prefer the caller's encoder name; fall back to the resolved spec's key.
    checkpoint_stem = encoder_storage_key(encoder or encoder_spec.key)
    checkpoint_path = output_dir / default_checkpoint_name(checkpoint_stem, dataset)
    metadata = {
        "encoder": checkpoint_stem,
        "encoder_model_name": encoder_spec.model_name,
        "dataset": dataset,
        "cache_dir": str(cache_dir),
        "model_config": {
            "input_dim": encoder_spec.embedding_dim,
            "output_dim": output_dim,
            "hidden_dim": hidden_dim,
            "dropout": dropout,
        },
        "training": training_metrics,
        "evaluation": evaluation_metrics,
    }
    # Make sure the destination exists so a missing output directory does not
    # fail only after the full training and evaluation run has completed.
    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
    save_checkpoint(checkpoint_path, model, metadata)
    return {
        "checkpoint": checkpoint_path,
        "training": training_metrics,
        "evaluation": evaluation_metrics,
        "metadata": metadata,
    }


def evaluate(
    *,
    encoder: Optional[str],
    dataset: str,
    cache_dir: Path,
    datasets_dir: Path,
    checkpoint_path: Path,
    device: str = "cpu",
    encoder_model_name: Optional[str] = None,
    embedding_dim: Optional[int] = None,
    query_prefix: str = "",
    passage_prefix: str = "",
    max_queries: Optional[int] = None,
    output_dim: int = 256,
    hidden_dim: int = 128,
    dropout: float = 0.1,
    feedback_k: int = 100,
    ranking_k: int = 10,
    k: int = 10,
) -> dict[str, Any]:
    """Load a saved IMRNN checkpoint and evaluate it on the cached test split.

    The model is rebuilt from the ``output_dim``/``hidden_dim``/``dropout``
    arguments, not from the checkpoint's stored metadata. These should
    therefore match the values used at training time. Mismatches surface
    through the returned ``missing_keys``/``unexpected_keys`` (exact
    behaviour depends on ``load_model``).

    Returns:
        A dict with keys ``"checkpoint"``, ``"metrics"``, ``"metadata"``,
        ``"missing_keys"`` and ``"unexpected_keys"``.
    """
    encoder_spec = resolve_encoder_spec(
        encoder=encoder,
        encoder_model_name=encoder_model_name,
        embedding_dim=embedding_dim,
        query_prefix=query_prefix,
        passage_prefix=passage_prefix,
    )
    beir_source = load_beir_source(dataset, datasets_dir=datasets_dir, max_queries=max_queries)
    test_split = load_cached_split(cache_dir, "test", beir_source, encoder_spec, device)
    model, metadata, missing, unexpected = load_model(
        checkpoint_path=checkpoint_path,
        model_config=ModelConfig(
            input_dim=encoder_spec.embedding_dim,
            output_dim=output_dim,
            hidden_dim=hidden_dim,
            dropout=dropout,
        ),
        device=device,
    )
    metrics = evaluate_model(
        model=model,
        cached_split=test_split,
        device=device,
        feedback_k=feedback_k,
        ranking_k=ranking_k,
        k_values=[k],
    )
    return {
        "checkpoint": checkpoint_path,
        "metrics": metrics,
        "metadata": metadata,
        "missing_keys": missing,
        "unexpected_keys": unexpected,
    }


def _cache_is_missing(cache_dir: Path) -> bool:
    """Return True when ``cache_dir`` is absent or is an empty directory."""
    if not cache_dir.exists():
        return True
    # An empty directory (pre-created by the user, or left by an aborted build)
    # cannot hold any cache, so it should trigger a (re)build.
    return cache_dir.is_dir() and next(cache_dir.iterdir(), None) is None


def run(
    *,
    encoder: Optional[str],
    dataset: str,
    cache_dir: Path,
    datasets_dir: Path,
    output_dir: Path,
    device: str = "cpu",
    encoder_model_name: Optional[str] = None,
    embedding_dim: Optional[int] = None,
    query_prefix: str = "",
    passage_prefix: str = "",
    max_queries: Optional[int] = None,
    batch_size: int = 32,
    epochs: int = 10,
    lr: float = 1e-4,
    weight_decay: float = 1e-5,
    num_negatives: int = 20,
    negative_pool: int = 200,
    output_dim: int = 256,
    hidden_dim: int = 128,
    dropout: float = 0.1,
    feedback_k: int = 100,
    ranking_k: int = 10,
    k: int = 10,
    cache_batch_size: Optional[int] = None,
) -> dict[str, Any]:
    """Build the embedding cache if needed, then train and evaluate.

    Caching runs when ``cache_dir`` does not exist or is an empty directory.
    ``negative_pool`` is only used for caching. ``cache_batch_size`` sets
    the encoding batch size during caching; when ``None`` it falls back to
    ``batch_size`` (the previous behaviour). All remaining arguments are
    forwarded to ``train``, whose result is returned.
    """
    # NOTE(review): a non-empty cache_dir is assumed to already hold the cache
    # for this encoder/dataset. load_cached_split is keyed on beir_source and
    # encoder_spec, so a directory shared between datasets may still lack
    # these splits; confirm the cache layout in build_cache.
    if _cache_is_missing(cache_dir):
        cache_embeddings(
            encoder=encoder,
            dataset=dataset,
            cache_dir=cache_dir,
            datasets_dir=datasets_dir,
            device=device,
            encoder_model_name=encoder_model_name,
            embedding_dim=embedding_dim,
            query_prefix=query_prefix,
            passage_prefix=passage_prefix,
            batch_size=batch_size if cache_batch_size is None else cache_batch_size,
            num_negatives=num_negatives,
            negative_pool=negative_pool,
            max_queries=max_queries,
        )
    return train(
        encoder=encoder,
        dataset=dataset,
        cache_dir=cache_dir,
        datasets_dir=datasets_dir,
        output_dir=output_dir,
        device=device,
        encoder_model_name=encoder_model_name,
        embedding_dim=embedding_dim,
        query_prefix=query_prefix,
        passage_prefix=passage_prefix,
        max_queries=max_queries,
        batch_size=batch_size,
        epochs=epochs,
        lr=lr,
        weight_decay=weight_decay,
        num_negatives=num_negatives,
        output_dim=output_dim,
        hidden_dim=hidden_dim,
        dropout=dropout,
        feedback_k=feedback_k,
        ranking_k=ranking_k,
        k=k,
    )