"""Command-line interface for training and evaluating IMRNNs over cached BEIR embeddings."""

from __future__ import annotations

import argparse
import json
from pathlib import Path

from .assets import (
    default_assets_root,
    discover_cached_embeddings,
    discover_checkpoints,
    discover_repo_checkpoints,
    package_root,
    resolve_cache_dir,
    resolve_checkpoint_path,
)
from .beir_data import load_beir_source
from .caching import build_cache
from .checkpoints import default_checkpoint_name, load_model, save_checkpoint
from .data import ContrastiveCachedDataset, load_cached_split
from .encoders import encoder_storage_key, resolve_encoder_spec
from .evaluation import evaluate_model
from .model import IMRNN, ModelConfig
from .training import TrainingConfig, train_model


def _add_common_args(parser: argparse.ArgumentParser) -> None:
    """Register the dataset/encoder/device arguments shared by most subcommands."""
    parser.add_argument("--assets-root", type=Path, default=default_assets_root())
    parser.add_argument("--datasets-dir", type=Path)
    parser.add_argument("--encoder")
    parser.add_argument("--encoder-model-name")
    parser.add_argument("--embedding-dim", type=int)
    parser.add_argument("--query-prefix", default="")
    parser.add_argument("--passage-prefix", default="")
    parser.add_argument("--dataset", required=True)
    parser.add_argument("--device", default="cuda")


def _resolve_encoder_spec(args: argparse.Namespace):
    return resolve_encoder_spec(
        encoder=args.encoder,
        encoder_model_name=args.encoder_model_name,
        embedding_dim=args.embedding_dim,
        query_prefix=args.query_prefix,
        passage_prefix=args.passage_prefix,
    )


def _encoder_label(args: argparse.Namespace, encoder_spec) -> str:
    # Prefer the user-supplied encoder alias; otherwise fall back to the resolved spec's key.
    if args.encoder:
        return encoder_storage_key(args.encoder)
    return encoder_storage_key(encoder_spec.key)


def _command_list_assets(args: argparse.Namespace) -> int:
    """Print a JSON inventory of cached embeddings and checkpoints."""
    payload = {
        "assets_root": str(args.assets_root),
        "repo_root": str(package_root()),
        "cached_embeddings": [
            {"encoder": item.encoder, "dataset": item.dataset, "path": str(item.path)}
            for item in discover_cached_embeddings(args.assets_root)
        ],
        "workspace_checkpoints": [
            {"encoder": item.encoder, "dataset": item.dataset, "path": str(item.path)}
            for item in discover_checkpoints(args.assets_root)
        ],
        "repo_checkpoints": [
            {"encoder": item.encoder, "dataset": item.dataset, "path": str(item.path)}
            for item in discover_repo_checkpoints(package_root())
        ],
    }
    print(json.dumps(payload, indent=2))
    return 0


def _load_training_inputs(args: argparse.Namespace):
    """Resolve the cache directory and load the cached train/val/test splits."""
    encoder_spec = _resolve_encoder_spec(args)
    encoder_label = _encoder_label(args, encoder_spec)
    cache_dir = args.cache_dir or resolve_cache_dir(args.assets_root, encoder_label, args.dataset)
    datasets_dir = args.datasets_dir or (args.assets_root / "datasets")
    beir_source = load_beir_source(args.dataset, datasets_dir=datasets_dir, max_queries=args.max_queries)
    train_split = load_cached_split(cache_dir, "train", beir_source, encoder_spec, args.device)
    val_split = load_cached_split(cache_dir, "val", beir_source, encoder_spec, args.device)
    test_split = load_cached_split(cache_dir, "test", beir_source, encoder_spec, args.device)
    return encoder_spec, cache_dir, train_split, val_split, test_split


def _k_values(args: argparse.Namespace) -> list[int]:
    # Ranking cutoffs for evaluation; currently just the single --k value.
    return [args.k]


def _command_cache(args: argparse.Namespace) -> int:
    """Download a BEIR dataset and cache its embeddings plus negatives."""
    encoder_spec = _resolve_encoder_spec(args)
    encoder_label = _encoder_label(args, encoder_spec)
    cache_dir = args.cache_dir or (args.assets_root / f"cache_{encoder_label}_{args.dataset}")
    built = build_cache(
        dataset_name=args.dataset,
        encoder_spec=encoder_spec,
        cache_dir=cache_dir,
        datasets_dir=args.datasets_dir or (args.assets_root / "datasets"),
        device=args.device,
        batch_size=args.batch_size,
        num_negatives=args.num_negatives,
        negative_pool=args.negative_pool,
    )
    print(
        json.dumps(
            {
                "cache_dir": str(built),
                "encoder": encoder_label,
                "encoder_model_name": encoder_spec.model_name,
                "dataset": args.dataset,
            },
            indent=2,
        )
    )
    return 0


def _command_train(args: argparse.Namespace) -> int:
    """Train an IMRNN on cached embeddings, evaluate on the test split, and save a checkpoint."""
    encoder_spec, cache_dir, train_split, val_split, test_split = _load_training_inputs(args)
    model = IMRNN(
        ModelConfig(
            input_dim=encoder_spec.embedding_dim,
            output_dim=args.output_dim,
            hidden_dim=args.hidden_dim,
            dropout=args.dropout,
        )
    )
    train_dataset = ContrastiveCachedDataset(train_split, args.num_negatives)
    val_dataset = ContrastiveCachedDataset(val_split, args.num_negatives)
    if len(train_dataset) == 0:
        raise ValueError(
            "No training examples were constructed from the cached split. "
            "Increase --max-queries or verify that cached negatives/embeddings match the dataset split."
        )
    if len(val_dataset) == 0:
        raise ValueError(
            "No validation examples were constructed from the cached split. "
            "Increase --max-queries or verify that cached negatives/embeddings match the dataset split."
        )
    metrics = train_model(
        model=model,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        config=TrainingConfig(
            batch_size=args.batch_size,
            epochs=args.epochs,
            lr=args.lr,
            weight_decay=args.weight_decay,
            num_negatives=args.num_negatives,
        ),
        device=args.device,
    )
    eval_metrics = evaluate_model(
        model=model,
        cached_split=test_split,
        device=args.device,
        feedback_k=args.feedback_k,
        ranking_k=args.ranking_k,
        k_values=_k_values(args),
    )
    output_dir = args.output_dir or args.assets_root
    encoder_label = _encoder_label(args, encoder_spec)
    checkpoint_path = output_dir / default_checkpoint_name(encoder_label, args.dataset)
    metadata = {
        "encoder": encoder_label,
        "encoder_model_name": encoder_spec.model_name,
        "dataset": args.dataset,
        "cache_dir": str(cache_dir),
        "model_config": {
            "input_dim": encoder_spec.embedding_dim,
            "output_dim": args.output_dim,
            "hidden_dim": args.hidden_dim,
            "dropout": args.dropout,
        },
        "training": metrics,
        "evaluation": eval_metrics,
    }
    save_checkpoint(checkpoint_path, model, metadata)
    print(json.dumps({"checkpoint": str(checkpoint_path), "training": metrics, "evaluation": eval_metrics}, indent=2))
    return 0


def _command_evaluate(args: argparse.Namespace) -> int:
    """Evaluate a saved IMRNN checkpoint on the cached test split."""
    encoder_spec = _resolve_encoder_spec(args)
    encoder_label = _encoder_label(args, encoder_spec)
    cache_dir = args.cache_dir or resolve_cache_dir(args.assets_root, encoder_label, args.dataset)
    checkpoint_path = args.checkpoint or resolve_checkpoint_path(args.assets_root, encoder_label, args.dataset)
    if checkpoint_path is None:
        raise FileNotFoundError(
            f"No checkpoint found for encoder='{encoder_label}' dataset='{args.dataset}'. Provide --checkpoint."
        )
    datasets_dir = args.datasets_dir or (args.assets_root / "datasets")
    beir_source = load_beir_source(args.dataset, datasets_dir=datasets_dir, max_queries=args.max_queries)
    test_split = load_cached_split(cache_dir, "test", beir_source, encoder_spec, args.device)
    model, metadata, missing, unexpected = load_model(
        checkpoint_path=checkpoint_path,
        model_config=ModelConfig(
            input_dim=encoder_spec.embedding_dim,
            output_dim=args.output_dim,
            hidden_dim=args.hidden_dim,
            dropout=args.dropout,
        ),
        device=args.device,
    )
    metrics = evaluate_model(
        model=model,
        cached_split=test_split,
        device=args.device,
        feedback_k=args.feedback_k,
        ranking_k=args.ranking_k,
        k_values=_k_values(args),
    )
    print(
        json.dumps(
            {
                "checkpoint": str(checkpoint_path),
                "metrics": metrics,
                "metadata": metadata,
                # State-dict keys that did not line up with the checkpoint, as reported by load_model.
                "missing_keys": missing,
                "unexpected_keys": unexpected,
            },
            indent=2,
        )
    )
    return 0


def _command_run(args: argparse.Namespace) -> int:
    """Build the cache if it is missing, then train and evaluate in one pass."""
    encoder_spec = _resolve_encoder_spec(args)
    encoder_label = _encoder_label(args, encoder_spec)
    cache_dir = args.cache_dir or (args.assets_root / f"cache_{encoder_label}_{args.dataset}")
    if not cache_dir.exists():
        # Shallow-copy the namespace so the cache step's override does not leak back.
        cache_args = argparse.Namespace(**vars(args))
        cache_args.cache_dir = cache_dir
        _command_cache(cache_args)
    train_args = argparse.Namespace(**vars(args))
    train_args.cache_dir = cache_dir
    return _command_train(train_args)


def build_parser() -> argparse.ArgumentParser:
    """Construct the top-level parser with the list-assets/train/evaluate/cache/run subcommands."""
    parser = argparse.ArgumentParser(description="Train and evaluate IMRNNs over cached BEIR embeddings.")
    subparsers = parser.add_subparsers(dest="command", required=True)

    list_assets = subparsers.add_parser("list-assets", help="List cached embeddings and checkpoints.")
    list_assets.add_argument("--assets-root", type=Path, default=default_assets_root())
    list_assets.set_defaults(func=_command_list_assets)

    train = subparsers.add_parser("train", help="Train IMRNNs from cached embeddings.")
    _add_common_args(train)
    train.add_argument("--cache-dir", type=Path)
    train.add_argument("--output-dir", type=Path)
    train.add_argument("--max-queries", type=int)
    train.add_argument("--batch-size", type=int, default=32)
    train.add_argument("--epochs", type=int, default=10)
    train.add_argument("--lr", type=float, default=1e-4)
    train.add_argument("--weight-decay", type=float, default=1e-5)
    train.add_argument("--num-negatives", type=int, default=20)
    train.add_argument("--output-dim", type=int, default=256)
    train.add_argument("--hidden-dim", type=int, default=128)
    train.add_argument("--dropout", type=float, default=0.1)
    train.add_argument("--feedback-k", type=int, default=100)
    train.add_argument("--ranking-k", type=int, default=10)
    train.add_argument("--k", type=int, default=10)
    train.set_defaults(func=_command_train)

    evaluate = subparsers.add_parser("evaluate", help="Evaluate an IMRNN checkpoint.")
    _add_common_args(evaluate)
    evaluate.add_argument("--cache-dir", type=Path)
    evaluate.add_argument("--checkpoint", type=Path)
    evaluate.add_argument("--max-queries", type=int)
    evaluate.add_argument("--output-dim", type=int, default=256)
    evaluate.add_argument("--hidden-dim", type=int, default=128)
    evaluate.add_argument("--dropout", type=float, default=0.1)
    evaluate.add_argument("--feedback-k", type=int, default=100)
    evaluate.add_argument("--ranking-k", type=int, default=10)
    evaluate.add_argument("--k", type=int, default=10)
    evaluate.set_defaults(func=_command_evaluate)

    cache = subparsers.add_parser("cache", help="Download a BEIR dataset and cache embeddings plus negatives.")
    _add_common_args(cache)
    cache.add_argument("--cache-dir", type=Path)
cache.add_argument("--batch-size", type=int, default=64) cache.add_argument("--num-negatives", type=int, default=20) cache.add_argument("--negative-pool", type=int, default=200) cache.set_defaults(func=_command_cache) run = subparsers.add_parser("run", help="Cache embeddings if needed, then train and evaluate IMRNNs end to end.") _add_common_args(run) run.add_argument("--cache-dir", type=Path) run.add_argument("--output-dir", type=Path) run.add_argument("--max-queries", type=int) run.add_argument("--batch-size", type=int, default=32) run.add_argument("--epochs", type=int, default=10) run.add_argument("--lr", type=float, default=1e-4) run.add_argument("--weight-decay", type=float, default=1e-5) run.add_argument("--num-negatives", type=int, default=20) run.add_argument("--negative-pool", type=int, default=200) run.add_argument("--output-dim", type=int, default=256) run.add_argument("--hidden-dim", type=int, default=128) run.add_argument("--dropout", type=float, default=0.1) run.add_argument("--feedback-k", type=int, default=100) run.add_argument("--ranking-k", type=int, default=10) run.add_argument("--k", type=int, default=10) run.set_defaults(func=_command_run) return parser def main() -> int: parser = build_parser() args = parser.parse_args() return args.func(args)