| from __future__ import annotations | |
| import argparse | |
| import json | |
| from pathlib import Path | |
| from .assets import ( | |
| default_assets_root, | |
| discover_cached_embeddings, | |
| discover_checkpoints, | |
| discover_repo_checkpoints, | |
| package_root, | |
| resolve_cache_dir, | |
| resolve_checkpoint_path, | |
| ) | |
| from .beir_data import load_beir_source | |
| from .caching import build_cache | |
| from .checkpoints import default_checkpoint_name, load_model, save_checkpoint | |
| from .data import ContrastiveCachedDataset, load_cached_split | |
| from .encoders import encoder_storage_key, normalize_encoder_name, resolve_encoder_spec | |
| from .evaluation import evaluate_model | |
| from .model import IMRNN, ModelConfig | |
| from .training import TrainingConfig, train_model | |
def _add_common_args(parser: argparse.ArgumentParser) -> None:
    """Register the options shared by the dataset-driven subcommands.

    Adds the asset and dataset locations, the encoder selection and its
    overrides, the required ``--dataset`` name, and ``--device`` to
    ``parser``.

    Args:
        parser: The (sub)parser that receives the options.
    """
    # (flag, add_argument keyword arguments), registered in this order.
    shared_options = (
        ("--assets-root", {"type": Path, "default": default_assets_root()}),
        ("--datasets-dir", {"type": Path}),
        ("--encoder", {}),
        ("--encoder-model-name", {}),
        ("--embedding-dim", {"type": int}),
        ("--query-prefix", {"default": ""}),
        ("--passage-prefix", {"default": ""}),
        ("--dataset", {"required": True}),
        ("--device", {"default": "cuda"}),
    )
    for flag, options in shared_options:
        parser.add_argument(flag, **options)
def _resolve_encoder_spec(args: argparse.Namespace):
    """Build the encoder spec from the encoder-related CLI options.

    Forwards ``--encoder``, ``--encoder-model-name``, ``--embedding-dim``,
    ``--query-prefix`` and ``--passage-prefix`` to ``resolve_encoder_spec``
    and returns its result unchanged.
    """
    spec_kwargs = {
        "encoder": args.encoder,
        "encoder_model_name": args.encoder_model_name,
        "embedding_dim": args.embedding_dim,
        "query_prefix": args.query_prefix,
        "passage_prefix": args.passage_prefix,
    }
    return resolve_encoder_spec(**spec_kwargs)
def _encoder_label(args: argparse.Namespace, encoder_spec) -> str:
    """Return the storage key that labels this encoder's caches and checkpoints.

    A truthy ``--encoder`` value takes precedence; otherwise the key of the
    resolved ``encoder_spec`` is used.
    """
    raw_name = args.encoder if args.encoder else encoder_spec.key
    return encoder_storage_key(raw_name)
def _command_list_assets(args: argparse.Namespace) -> int:
    """Print a JSON inventory of cached embeddings and checkpoints.

    The report lists the assets root, the package root, the cached embeddings
    and checkpoints discovered under ``--assets-root``, and the checkpoints
    discovered under the package root.

    Returns:
        Always ``0``.
    """

    def _describe(items) -> list:
        # Every discovered item is reported by encoder, dataset and path.
        return [
            {"encoder": item.encoder, "dataset": item.dataset, "path": str(item.path)}
            for item in items
        ]

    report = {"assets_root": str(args.assets_root)}
    report["repo_root"] = str(package_root())
    report["cached_embeddings"] = _describe(discover_cached_embeddings(args.assets_root))
    report["workspace_checkpoints"] = _describe(discover_checkpoints(args.assets_root))
    report["repo_checkpoints"] = _describe(discover_repo_checkpoints(package_root()))

    print(json.dumps(report, indent=2))
    return 0
def _load_training_inputs(args: argparse.Namespace):
    """Resolve the encoder and cache location and load the three cached splits.

    ``--cache-dir`` and ``--datasets-dir`` win when given; otherwise the cache
    directory comes from ``resolve_cache_dir`` and the datasets directory is
    ``<assets-root>/datasets``.

    Returns:
        A tuple ``(encoder_spec, cache_dir, train_split, val_split, test_split)``.
    """
    spec = _resolve_encoder_spec(args)
    label = _encoder_label(args, spec)

    cache_dir = args.cache_dir
    if not cache_dir:
        cache_dir = resolve_cache_dir(args.assets_root, label, args.dataset)

    datasets_dir = args.datasets_dir
    if not datasets_dir:
        datasets_dir = args.assets_root / "datasets"

    source = load_beir_source(args.dataset, datasets_dir=datasets_dir, max_queries=args.max_queries)

    # Splits are loaded in train, val, test order.
    train_split, val_split, test_split = [
        load_cached_split(cache_dir, split_name, source, spec, args.device)
        for split_name in ("train", "val", "test")
    ]
    return spec, cache_dir, train_split, val_split, test_split
def _k_values(args: argparse.Namespace) -> list[int]:
    """Return the cutoff list passed to evaluation: a one-element list holding ``--k``."""
    cutoff = args.k
    return [cutoff]
def _command_cache(args: argparse.Namespace) -> int:
    """Build the embedding cache for one encoder/dataset pair and print a JSON summary.

    The cache goes to ``--cache-dir`` when given, otherwise to
    ``<assets-root>/cache_<encoder-label>_<dataset>``. The datasets directory
    defaults to ``<assets-root>/datasets``.

    Returns:
        Always ``0``.
    """
    spec = _resolve_encoder_spec(args)
    label = _encoder_label(args, spec)

    target_dir = args.cache_dir
    if not target_dir:
        target_dir = args.assets_root / f"cache_{label}_{args.dataset}"

    datasets_dir = args.datasets_dir
    if not datasets_dir:
        datasets_dir = args.assets_root / "datasets"

    built = build_cache(
        dataset_name=args.dataset,
        encoder_spec=spec,
        cache_dir=target_dir,
        datasets_dir=datasets_dir,
        device=args.device,
        batch_size=args.batch_size,
        num_negatives=args.num_negatives,
        negative_pool=args.negative_pool,
    )

    # Report the directory returned by build_cache, not the requested one.
    summary = {
        "cache_dir": str(built),
        "encoder": label,
        "encoder_model_name": spec.model_name,
        "dataset": args.dataset,
    }
    print(json.dumps(summary, indent=2))
    return 0
def _command_train(args: argparse.Namespace) -> int:
    """Train an IMRNN on cached embeddings, evaluate it on the test split, and save a checkpoint.

    The checkpoint is written to ``--output-dir`` (falling back to
    ``--assets-root``) under the name given by ``default_checkpoint_name``.
    A JSON summary with the checkpoint path and the training and evaluation
    metrics is printed.

    Args:
        args: Parsed options of the ``train`` (or ``run``) subcommand.

    Returns:
        Always ``0``.

    Raises:
        ValueError: If the training or validation dataset built from the
            cached split is empty.
    """
    encoder_spec, cache_dir, train_split, val_split, test_split = _load_training_inputs(args)

    # Defined once: used to build the model and recorded in the checkpoint
    # metadata, so the two cannot drift apart.
    model_config = {
        "input_dim": encoder_spec.embedding_dim,
        "output_dim": args.output_dim,
        "hidden_dim": args.hidden_dim,
        "dropout": args.dropout,
    }
    model = IMRNN(ModelConfig(**model_config))

    train_dataset = ContrastiveCachedDataset(train_split, args.num_negatives)
    val_dataset = ContrastiveCachedDataset(val_split, args.num_negatives)
    if len(train_dataset) == 0:
        raise ValueError(
            "No training examples were constructed from the cached split. "
            "Increase --max-queries or verify that cached negatives/embeddings match the dataset split."
        )
    if len(val_dataset) == 0:
        raise ValueError(
            "No validation examples were constructed from the cached split. "
            "Increase --max-queries or verify that cached negatives/embeddings match the dataset split."
        )

    # Create the destination before the expensive training run so that a
    # missing or unusable output directory fails fast instead of after
    # training and evaluation have completed. This is a no-op when the
    # directory already exists.
    output_dir = args.output_dir or args.assets_root
    output_dir.mkdir(parents=True, exist_ok=True)

    metrics = train_model(
        model=model,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        config=TrainingConfig(
            batch_size=args.batch_size,
            epochs=args.epochs,
            lr=args.lr,
            weight_decay=args.weight_decay,
            num_negatives=args.num_negatives,
        ),
        device=args.device,
    )
    eval_metrics = evaluate_model(
        model=model,
        cached_split=test_split,
        device=args.device,
        feedback_k=args.feedback_k,
        ranking_k=args.ranking_k,
        k_values=_k_values(args),
    )

    encoder_label = _encoder_label(args, encoder_spec)
    checkpoint_path = output_dir / default_checkpoint_name(encoder_label, args.dataset)
    metadata = {
        "encoder": encoder_label,
        "encoder_model_name": encoder_spec.model_name,
        "dataset": args.dataset,
        "cache_dir": str(cache_dir),
        "model_config": model_config,
        "training": metrics,
        "evaluation": eval_metrics,
    }
    save_checkpoint(checkpoint_path, model, metadata)
    print(json.dumps({"checkpoint": str(checkpoint_path), "training": metrics, "evaluation": eval_metrics}, indent=2))
    return 0
def _command_evaluate(args: argparse.Namespace) -> int:
    """Evaluate a saved IMRNN checkpoint on the cached test split and print a JSON report.

    ``--cache-dir`` and ``--checkpoint`` win when given; otherwise both are
    resolved from ``--assets-root``, the encoder label and the dataset name.
    The report contains the checkpoint path, the metrics, the checkpoint
    metadata, and the missing and unexpected keys returned by ``load_model``.

    Returns:
        Always ``0``.

    Raises:
        FileNotFoundError: If no checkpoint was given and none could be resolved.
    """
    encoder_spec = _resolve_encoder_spec(args)
    encoder_label = _encoder_label(args, encoder_spec)

    cache_dir = args.cache_dir
    if not cache_dir:
        cache_dir = resolve_cache_dir(args.assets_root, encoder_label, args.dataset)

    checkpoint_path = args.checkpoint
    if not checkpoint_path:
        checkpoint_path = resolve_checkpoint_path(args.assets_root, encoder_label, args.dataset)
    if checkpoint_path is None:
        raise FileNotFoundError(
            f"No checkpoint found for encoder='{encoder_label}' dataset='{args.dataset}'. Provide --checkpoint."
        )

    datasets_dir = args.datasets_dir
    if not datasets_dir:
        datasets_dir = args.assets_root / "datasets"
    source = load_beir_source(args.dataset, datasets_dir=datasets_dir, max_queries=args.max_queries)
    test_split = load_cached_split(cache_dir, "test", source, encoder_spec, args.device)

    # The architecture comes from the CLI options, not from the checkpoint.
    config = ModelConfig(
        input_dim=encoder_spec.embedding_dim,
        output_dim=args.output_dim,
        hidden_dim=args.hidden_dim,
        dropout=args.dropout,
    )
    model, metadata, missing, unexpected = load_model(
        checkpoint_path=checkpoint_path,
        model_config=config,
        device=args.device,
    )

    metrics = evaluate_model(
        model=model,
        cached_split=test_split,
        device=args.device,
        feedback_k=args.feedback_k,
        ranking_k=args.ranking_k,
        k_values=_k_values(args),
    )

    report = {
        "checkpoint": str(checkpoint_path),
        "metrics": metrics,
        "metadata": metadata,
        "missing_keys": missing,
        "unexpected_keys": unexpected,
    }
    print(json.dumps(report, indent=2))
    return 0
def _command_run(args: argparse.Namespace) -> int:
    """Build the cache if its directory does not exist, then train and evaluate.

    The cache directory is ``--cache-dir`` when given, otherwise
    ``<assets-root>/cache_<encoder-label>_<dataset>``. Both delegated commands
    receive their own copy of ``args`` with ``cache_dir`` set to that path.

    Returns:
        The status returned by ``_command_train``.
    """
    spec = _resolve_encoder_spec(args)
    label = _encoder_label(args, spec)

    cache_dir = args.cache_dir
    if not cache_dir:
        cache_dir = args.assets_root / f"cache_{label}_{args.dataset}"

    def _with_cache_dir() -> argparse.Namespace:
        # Fresh copy per call, so the caller's namespace is never mutated.
        return argparse.Namespace(**{**vars(args), "cache_dir": cache_dir})

    if not cache_dir.exists():
        _command_cache(_with_cache_dir())
    return _command_train(_with_cache_dir())
def _add_training_args(parser: argparse.ArgumentParser) -> None:
    """Register the optimisation options shared by ``train`` and ``run``."""
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--weight-decay", type=float, default=1e-5)
    parser.add_argument("--num-negatives", type=int, default=20)


def _add_model_and_eval_args(parser: argparse.ArgumentParser) -> None:
    """Register the model-shape and evaluation options shared by ``train``, ``evaluate`` and ``run``."""
    parser.add_argument("--output-dim", type=int, default=256)
    parser.add_argument("--hidden-dim", type=int, default=128)
    parser.add_argument("--dropout", type=float, default=0.1)
    parser.add_argument("--feedback-k", type=int, default=100)
    parser.add_argument("--ranking-k", type=int, default=10)
    parser.add_argument("--k", type=int, default=10)


def build_parser() -> argparse.ArgumentParser:
    """Build the top-level CLI parser.

    Defines the ``list-assets``, ``train``, ``evaluate``, ``cache`` and ``run``
    subcommands. A subcommand is required, and each one stores its handler
    under ``func`` so the caller can dispatch with ``args.func(args)``.

    Returns:
        The configured parser.
    """
    parser = argparse.ArgumentParser(description="Train and evaluate IMRNNs over cached BEIR embeddings.")
    subparsers = parser.add_subparsers(dest="command", required=True)

    list_assets = subparsers.add_parser("list-assets", help="List cached embeddings and checkpoints.")
    list_assets.add_argument("--assets-root", type=Path, default=default_assets_root())
    list_assets.set_defaults(func=_command_list_assets)

    train = subparsers.add_parser("train", help="Train IMRNNs from cached embeddings.")
    _add_common_args(train)
    train.add_argument("--cache-dir", type=Path)
    train.add_argument("--output-dir", type=Path)
    train.add_argument("--max-queries", type=int)
    _add_training_args(train)
    _add_model_and_eval_args(train)
    train.set_defaults(func=_command_train)

    evaluate = subparsers.add_parser("evaluate", help="Evaluate an IMRNN checkpoint.")
    _add_common_args(evaluate)
    evaluate.add_argument("--cache-dir", type=Path)
    evaluate.add_argument("--checkpoint", type=Path)
    evaluate.add_argument("--max-queries", type=int)
    _add_model_and_eval_args(evaluate)
    evaluate.set_defaults(func=_command_evaluate)

    # ``cache`` keeps its own --batch-size default (64): here it is passed to
    # build_cache, not to training.
    cache = subparsers.add_parser("cache", help="Download a BEIR dataset and cache embeddings plus negatives.")
    _add_common_args(cache)
    cache.add_argument("--cache-dir", type=Path)
    cache.add_argument("--batch-size", type=int, default=64)
    cache.add_argument("--num-negatives", type=int, default=20)
    cache.add_argument("--negative-pool", type=int, default=200)
    cache.set_defaults(func=_command_cache)

    run = subparsers.add_parser("run", help="Cache embeddings if needed, then train and evaluate IMRNNs end to end.")
    _add_common_args(run)
    run.add_argument("--cache-dir", type=Path)
    run.add_argument("--output-dir", type=Path)
    run.add_argument("--max-queries", type=int)
    _add_training_args(run)
    # ``run`` may build the cache itself, so it also needs the negative pool size.
    run.add_argument("--negative-pool", type=int, default=200)
    _add_model_and_eval_args(run)
    run.set_defaults(func=_command_run)

    return parser
def main(argv: list[str] | None = None) -> int:
    """CLI entry point: parse the arguments and dispatch to the selected subcommand.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read them from ``sys.argv``, so existing ``main()`` callers are
            unaffected. Passing an explicit list allows programmatic use.

    Returns:
        The integer status returned by the subcommand handler.
    """
    parser = build_parser()
    args = parser.parse_args(argv)
    return args.func(args)