| """Utility methods to define and validate flags.""" | |
| import argparse | |
| import logging | |
| import os | |
| import re | |
| from typing import Callable | |
| import pkg_resources | |
| from configargparse import ArgumentParser | |
| from sage.reranker import RerankerProvider | |


# Limits defined here: https://ai.google.dev/gemini-api/docs/models/gemini
GEMINI_MAX_TOKENS_PER_CHUNK = 2048

MARQO_MAX_CHUNKS_PER_BATCH = 64

# The ADA embedder from OpenAI has a maximum of 8192 tokens.
OPENAI_MAX_TOKENS_PER_CHUNK = 8192
# The OpenAI batch embedding API enforces a maximum of 2048 chunks per batch.
OPENAI_MAX_CHUNKS_PER_BATCH = 2048
# The OpenAI batch embedding API enforces a maximum of 3M tokens processed at once.
OPENAI_MAX_TOKENS_PER_JOB = 3_000_000

# Note that OpenAI embedding models have fixed dimensions; however, taking a slice of them is possible.
# See "Reducing embedding dimensions" under https://platform.openai.com/docs/guides/embeddings/use-cases and
# https://platform.openai.com/docs/api-reference/embeddings/create#embeddings-create-dimensions
OPENAI_DEFAULT_EMBEDDING_SIZE = {
    "text-embedding-ada-002": 1536,
    "text-embedding-3-small": 1536,
    "text-embedding-3-large": 3072,
}
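# Illustrative only (not used in this module): with an OpenAI client object, a
# truncated embedding can be requested via the `dimensions` parameter, e.g.
#   client.embeddings.create(model="text-embedding-3-small", input=text, dimensions=256)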

VOYAGE_MAX_CHUNKS_PER_BATCH = 128


def get_voyage_max_tokens_per_batch(model: str) -> int:
    """Returns the maximum number of tokens per batch for the Voyage model.
    See https://docs.voyageai.com/reference/embeddings-api."""
    if model == "voyage-3-lite":
        return 1_000_000
    if model in ["voyage-3", "voyage-2"]:
        return 320_000
    return 120_000


def get_voyage_embedding_size(model: str) -> int:
    """Returns the embedding size for the Voyage model. See https://docs.voyageai.com/docs/embeddings#model-choices."""
    if model == "voyage-3-lite":
        return 512
    if model == "voyage-code-2":
        return 1536
    return 1024


def add_config_args(parser: ArgumentParser) -> Callable:
    """Adds configuration-related arguments to the parser and returns a (trivial) validator."""
    parser.add(
        "--mode",
        choices=["local", "remote"],
        default="local",
        help="Whether to use local-only resources or call third-party providers.",
    )
    parser.add(
        "--config",
        is_config_file=True,
        help="Path to .yaml configuration file.",
    )
    # Peek at --mode so we can point --config at the matching default file.
    args, _ = parser.parse_known_args()
    config_file = pkg_resources.resource_filename(__name__, f"configs/{args.mode}.yaml")
    parser.set_defaults(config=config_file)
    # Trivial validator (nothing to check).
    return lambda _: True
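

# An illustrative config file (e.g. configs/local.yaml) maps long flag names to
# values, following configargparse conventions:
#   embedding-provider: marqo
#   llm-provider: ollama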


def add_repo_args(parser: ArgumentParser) -> Callable:
    """Adds repository-related arguments to the parser and returns a validator."""
    parser.add("repo_id", help="The ID of the repository to index")
    parser.add("--commit-hash", help="Optional commit hash to checkout. When not provided, defaults to HEAD.")
    parser.add(
        "--local-dir",
        default="repos",
        help="The local directory to store the repository",
    )
    return validate_repo_args


def add_embedding_args(parser: ArgumentParser) -> Callable:
    """Adds embedding-related arguments to the parser and returns a validator."""
    parser.add("--embedding-provider", default="marqo", choices=["openai", "voyage", "marqo", "gemini"])
    parser.add(
        "--embedding-model",
        type=str,
        default=None,
        help="The embedding model. Defaults to `text-embedding-3-small` for OpenAI and `hf/e5-base-v2` for Marqo.",
    )
    parser.add(
        "--embedding-size",
        type=int,
        default=None,
        help="The embedding size to use for OpenAI text-embedding-3* models. Defaults to 1536 for small and 3072 for "
        "large. Note that no other OpenAI models support a dynamic embedding size, nor do models used with Marqo.",
    )
    parser.add(
        "--tokens-per-chunk",
        type=int,
        default=800,
        help="https://arxiv.org/pdf/2406.14497 recommends a value between 200-800.",
    )
    parser.add(
        "--chunks-per-batch",
        type=int,
        help="Maximum chunks per batch. We recommend 2000 for the OpenAI embedder. Marqo enforces a limit of 64.",
    )
    parser.add(
        "--max-embedding-jobs",
        type=int,
        help="Maximum number of embedding jobs to run. Specifying this might result in "
        "indexing only part of the repository, but prevents you from burning through OpenAI credits.",
    )
    return validate_embedding_args


def add_vector_store_args(parser: ArgumentParser) -> Callable:
    """Adds vector store-related arguments to the parser and returns a validator."""
    parser.add("--vector-store-provider", default="marqo", choices=["pinecone", "marqo"])
    parser.add(
        "--pinecone-index-name",
        default=None,
        help="Pinecone index name. Required if using Pinecone as the vector store. If the index doesn't exist already, "
        "we will create it.",
    )
    parser.add(
        "--index-namespace",
        default=None,
        help="Index namespace for this repo. When not specified, we default it to a derivative of the repo name.",
    )
    parser.add(
        "--marqo-url",
        default="http://localhost:8882",
        help="URL for the Marqo server. Required if using Marqo as embedder or vector store.",
    )
    parser.add(
        "--retrieval-alpha",
        default=1.0,
        type=float,
        help="Takes effect for the Pinecone retriever only. The weight of the dense (embeddings-based) vs. sparse "
        "(BM25) encoder in the final retrieval score. A value of 0.0 means BM25 only, 1.0 means embeddings only.",
    )
    parser.add(
        "--retriever-top-k", default=25, type=int, help="The number of top documents to retrieve from the vector store."
    )
    parser.add(
        "--multi-query-retriever",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="When set to True, we rewrite the query 5 times, perform retrieval for each rewrite, and take the union "
        "of retrieved documents. See https://python.langchain.com/v0.1/docs/modules/data_connection/retrievers/MultiQueryRetriever/.",
    )
    parser.add(
        "--llm-retriever",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="When set to True, we use an LLM for retrieval: we pass the repository file hierarchy together with the "
        "user query and ask the LLM to choose relevant files solely based on their paths. No indexing will be done, so "
        "all the vector store / embedding arguments will be ignored.",
    )
    return validate_vector_store_args


def add_indexing_args(parser: ArgumentParser) -> Callable:
    """Adds indexing-related arguments to the parser and returns a validator."""
    parser.add(
        "--include",
        help="Path to a file containing a list of extensions to include. One extension per line.",
    )
    parser.add(
        "--exclude",
        help="Path to a file containing a list of extensions to exclude. One extension per line.",
    )
    # Pass --no-index-repo in order to not index the repository.
    parser.add(
        "--index-repo",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Whether to index the repository. At least one of --index-repo and --index-issues must be True.",
    )
    # Pass --no-index-issues in order to not index the issues.
    parser.add(
        "--index-issues",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Whether to index GitHub issues. At least one of --index-repo and --index-issues must be True. When "
        "--index-issues is set, you must also set a GITHUB_TOKEN environment variable.",
    )
    # Pass --no-index-issue-comments in order to not index the comments of GitHub issues.
    parser.add(
        "--index-issue-comments",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Whether to index the comments of GitHub issues. This is only relevant if --index-issues is set. "
        "GitHub's API for downloading comments is quite slow. Indexing solely the body of an issue seems to bring most "
        "of the gains anyway.",
    )
    return validate_indexing_args


def add_reranking_args(parser: ArgumentParser) -> Callable:
    """Adds reranking-related arguments to the parser and returns a (trivial) validator."""
    parser.add("--reranker-provider", default="huggingface", choices=[r.value for r in RerankerProvider])
    parser.add(
        "--reranker-model",
        help="The reranker model name. When --reranker-provider=huggingface, we suggest choosing a model from the "
        "SentenceTransformers Cross-Encoders library: https://huggingface.co/cross-encoder?sort_models=downloads#models",
    )
    parser.add("--reranker-top-k", default=5, type=int, help="The number of top documents to return after reranking.")
    # Trivial validator (nothing to check).
    return lambda _: True


def add_llm_args(parser: ArgumentParser) -> Callable:
    """Adds language model-related arguments to the parser and returns a (trivial) validator."""
    parser.add("--llm-provider", default="ollama", choices=["openai", "anthropic", "ollama"])
    parser.add(
        "--llm-model",
        help="The LLM name. Must be supported by the provider specified via --llm-provider.",
    )
    # Trivial validator (nothing to check).
    return lambda _: True


def add_all_args(parser: ArgumentParser) -> Callable:
    """Adds all arguments to the parser and returns a validator."""
    arg_validators = [
        add_config_args(parser),
        add_repo_args(parser),
        add_embedding_args(parser),
        add_vector_store_args(parser),
        add_reranking_args(parser),
        add_indexing_args(parser),
        add_llm_args(parser),
    ]

    def validate_all(args):
        for validator in arg_validators:
            validator(args)

    return validate_all


def validate_repo_args(args):
    """Validates the configuration of the repository."""
    if not re.match(r"^[^/]+/[^/]+$", args.repo_id):
        # e.g. "huggingface/transformers" is valid; a bare repo name or a full URL is not.
        raise ValueError("repo_id must be in the format 'owner/repo'")


def _validate_openai_embedding_args(args):
    """Validates the configuration of the OpenAI batch embedder and sets defaults."""
    if args.embedding_provider == "openai" and not os.getenv("OPENAI_API_KEY"):
        raise ValueError("Please set the OPENAI_API_KEY environment variable.")
    if not args.embedding_model:
        args.embedding_model = "text-embedding-3-small"
    if args.embedding_model not in OPENAI_DEFAULT_EMBEDDING_SIZE:
        raise ValueError(f"Unrecognized embeddings.model={args.embedding_model}")
    if not args.embedding_size:
        args.embedding_size = OPENAI_DEFAULT_EMBEDDING_SIZE.get(args.embedding_model)
    if not args.tokens_per_chunk:
        # https://arxiv.org/pdf/2406.14497 recommends a value between 200-800.
        args.tokens_per_chunk = 800
    elif args.tokens_per_chunk > OPENAI_MAX_TOKENS_PER_CHUNK:
        args.tokens_per_chunk = OPENAI_MAX_TOKENS_PER_CHUNK
        logging.warning(
            f"OpenAI enforces a limit of {OPENAI_MAX_TOKENS_PER_CHUNK} tokens per chunk. "
            "Overwriting embeddings.tokens_per_chunk."
        )
    if not args.chunks_per_batch:
        args.chunks_per_batch = OPENAI_MAX_CHUNKS_PER_BATCH
    elif args.chunks_per_batch > OPENAI_MAX_CHUNKS_PER_BATCH:
        args.chunks_per_batch = OPENAI_MAX_CHUNKS_PER_BATCH
        logging.warning(
            f"OpenAI enforces a limit of {OPENAI_MAX_CHUNKS_PER_BATCH} chunks per batch. "
            "Overwriting embeddings.chunks_per_batch."
        )
    tokens_per_job = args.tokens_per_chunk * args.chunks_per_batch
    if tokens_per_job >= OPENAI_MAX_TOKENS_PER_JOB:
        raise ValueError(f"The maximum number of tokens per job is {OPENAI_MAX_TOKENS_PER_JOB}. Got {tokens_per_job}")


def _validate_voyage_embedding_args(args):
    """Validates the configuration of the Voyage batch embedder and sets defaults."""
    if args.embedding_provider == "voyage" and not os.getenv("VOYAGE_API_KEY"):
        raise ValueError("Please set the VOYAGE_API_KEY environment variable.")
    if not args.embedding_model:
        args.embedding_model = "voyage-code-2"
    if not args.tokens_per_chunk:
        # https://arxiv.org/pdf/2406.14497 recommends a value between 200-800.
        args.tokens_per_chunk = 800
    if not args.chunks_per_batch:
        args.chunks_per_batch = VOYAGE_MAX_CHUNKS_PER_BATCH
    elif args.chunks_per_batch > VOYAGE_MAX_CHUNKS_PER_BATCH:
        args.chunks_per_batch = VOYAGE_MAX_CHUNKS_PER_BATCH
        logging.warning(f"Voyage enforces a limit of {VOYAGE_MAX_CHUNKS_PER_BATCH} chunks per batch. Overwriting.")
    max_tokens = get_voyage_max_tokens_per_batch(args.embedding_model)
    if args.tokens_per_chunk * args.chunks_per_batch > max_tokens:
        raise ValueError(
            f"Voyage enforces a limit of {max_tokens} tokens per batch. "
            "Reduce either --tokens-per-chunk or --chunks-per-batch."
        )
    if not args.embedding_size:
        args.embedding_size = get_voyage_embedding_size(args.embedding_model)


def _validate_marqo_embedding_args(args):
    """Validates the configuration of the Marqo batch embedder and sets defaults."""
    if not args.embedding_model:
        args.embedding_model = "hf/e5-base-v2"
    if not args.chunks_per_batch:
        args.chunks_per_batch = MARQO_MAX_CHUNKS_PER_BATCH
    elif args.chunks_per_batch > MARQO_MAX_CHUNKS_PER_BATCH:
        args.chunks_per_batch = MARQO_MAX_CHUNKS_PER_BATCH
        logging.warning(
            f"Marqo enforces a limit of {MARQO_MAX_CHUNKS_PER_BATCH} chunks per batch. "
            "Overwriting embeddings.chunks_per_batch."
        )


def _validate_gemini_embedding_args(args):
    """Validates the configuration of the Gemini batch embedder and sets defaults."""
    if not args.embedding_model:
        args.embedding_model = "models/text-embedding-004"
    if not os.getenv("GOOGLE_API_KEY"):
        raise ValueError("Please set the GOOGLE_API_KEY environment variable if using `gemini` embeddings.")
    if not args.chunks_per_batch:
        # This value is reasonable but arbitrary (i.e. Gemini does not explicitly enforce a limit).
        args.chunks_per_batch = 2000
    if not args.tokens_per_chunk:
        args.tokens_per_chunk = GEMINI_MAX_TOKENS_PER_CHUNK
    if not args.embedding_size:
        args.embedding_size = 768


def validate_embedding_args(args):
    """Validates the configuration of the batch embedder and sets defaults."""
    if args.embedding_provider == "openai":
        _validate_openai_embedding_args(args)
    elif args.embedding_provider == "voyage":
        _validate_voyage_embedding_args(args)
    elif args.embedding_provider == "marqo":
        _validate_marqo_embedding_args(args)
    elif args.embedding_provider == "gemini":
        _validate_gemini_embedding_args(args)
    else:
        raise ValueError(f"Unrecognized --embedding-provider={args.embedding_provider}")


def validate_vector_store_args(args):
    """Validates the configuration of the vector store and sets defaults."""
    if args.llm_retriever:
        if not os.getenv("ANTHROPIC_API_KEY"):
            raise ValueError(
                "Please set the ANTHROPIC_API_KEY environment variable to use the LLM retriever. "
                "(We're constrained to Claude because we need prompt caching.)"
            )
        if args.index_issues:
            # The LLM retriever only makes sense on the code repository, since it passes file paths to the LLM.
            raise ValueError("Cannot use --index-issues with --llm-retriever.")
        # When using an LLM retriever, all the vector store arguments are ignored.
        return

    if not args.index_namespace:
        # Attempt to derive a default index namespace from the repository information.
        if "repo_id" not in args:
            raise ValueError("Please set a value for --index-namespace.")
        args.index_namespace = args.repo_id
        if "commit_hash" in args and args.commit_hash:
            args.index_namespace += "/" + args.commit_hash
        if args.vector_store_provider == "marqo":
            # Marqo namespaces must match this pattern: [a-zA-Z_-][a-zA-Z0-9_-]*
            # e.g. "owner/repo" becomes "owner_repo".
            args.index_namespace = re.sub(r"[^a-zA-Z0-9_-]", "_", args.index_namespace)

    if args.vector_store_provider == "marqo":
        if not args.marqo_url:
            args.marqo_url = "http://localhost:8882"
        if "/" in args.index_namespace:
            raise ValueError(f"Marqo doesn't allow slashes in --index-namespace={args.index_namespace}.")
    elif args.vector_store_provider == "pinecone":
        if not os.getenv("PINECONE_API_KEY"):
            raise ValueError("Please set the PINECONE_API_KEY environment variable.")
        if not args.pinecone_index_name:
            raise ValueError("Please set the vector_store.pinecone_index_name value.")


def validate_indexing_args(args):
    """Validates the indexing configuration and sets defaults."""
    if args.include and args.exclude:
        raise ValueError("At most one of indexing.include and indexing.exclude can be specified.")
    if not args.include and not args.exclude:
        args.exclude = pkg_resources.resource_filename(__name__, "sample-exclude.txt")
    if args.include and not os.path.exists(args.include):
        raise ValueError(f"Path --include={args.include} does not exist.")
    if args.exclude and not os.path.exists(args.exclude):
        raise ValueError(f"Path --exclude={args.exclude} does not exist.")
    if not args.index_repo and not args.index_issues:
        raise ValueError("At least one of --index-repo and --index-issues must be set to true.")
    if args.index_issues and not os.getenv("GITHUB_TOKEN"):
        raise ValueError("Please set the GITHUB_TOKEN environment variable.")