juliaturc committed
Commit 8699925 · Parent(s): 9802b75

Retrieval benchmark (#39)


* Fixes for previous PR

* Add retriever.py

* Add retrieval benchmark

* Add --retrieval-alpha flag

* Fit BM25 to the current corpus and add Voyage embeddings

* Add benchmark README.

* Nits to retrieve.py

* Add Voyage reranker.

* Update README to reflect Voyage embeddings and reranker.

* Address reviewer comments

README.md CHANGED
@@ -72,22 +72,28 @@ pip install git+https://github.com/Storia-AI/sage.git@main
 <details>
 <summary><strong>:cloud: Using external providers (higher quality)</strong></summary>

-1. We support <a href="https://openai.com/">OpenAI</a> for embeddings (they have a super fast batch embedding API) and <a href="https://www.pinecone.io/">Pinecone</a> for the vector store. So you will need two API keys:
+1. For embeddings, we support <a href="https://platform.openai.com/docs/guides/embeddings">OpenAI</a> and <a href="https://docs.voyageai.com/docs/embeddings">Voyage</a>. According to [our experiments](benchmarks/retrieval/README.md), OpenAI offers better quality. Their batch API is also faster, with more generous rate limits. Export the API key of the desired provider:

 ```
-export OPENAI_API_KEY=...
-export PINECONE_API_KEY=...
+export OPENAI_API_KEY=... # or
+export VOYAGE_API_KEY=...
 ```

-2. Create a Pinecone account. Export the desired index name (if it doesn't exist yet, we'll create it):
+2. We use <a href="https://www.pinecone.io/">Pinecone</a> for the vector store, so you will need an API key:
+
+```
+export PINECONE_API_KEY=...
+```
+
+If you want to reuse an existing Pinecone index, specify it. Otherwise we'll create a new one called `sage`.
 ```
 export PINECONE_INDEX_NAME=...
 ```

-3. For reranking, we use <a href="https://cohere.com/rerank">Cohere</a> by default, but you can also try rerankers from <a href="https://developer.nvidia.com/blog/enhancing-rag-pipelines-with-re-ranking/">NVIDIA</a> or <a href="https://jina.ai/reranker/">Jina</a>:
+3. For reranking, we support <a href="https://developer.nvidia.com/blog/enhancing-rag-pipelines-with-re-ranking/">NVIDIA</a>, <a href="https://docs.voyageai.com/docs/reranker">Voyage</a>, <a href="https://cohere.com/rerank">Cohere</a>, and <a href="https://jina.ai/reranker/">Jina</a>. According to [our experiments](benchmarks/retrieval/README.md), NVIDIA performs best. Export the API key of the desired provider:
 ```
-export COHERE_API_KEY=... # or
 export NVIDIA_API_KEY=... # or
+export VOYAGE_API_KEY=... # or
+export COHERE_API_KEY=... # or
 export JINA_API_KEY=...
 ```
benchmarks/retrieval/README.md ADDED
@@ -0,0 +1,132 @@
# Chat-with-your-codebase: Retrieval Benchmark

When using this repository (which allows you to chat with your codebase in two commands), you are indirectly making a series of choices that greatly influence the quality of your AI copilot: chunking strategy, embeddings, retrieval algorithm, rerankers, etc.

Our role as maintainers is two-fold: to give you options and flexibility, but also to find good defaults. We're not here just to dump code on the Internet. We're here to *make it work*.

To make progress, we need a ladder to climb. That's why we partnered with our friends at [Morph Labs](https://morph.so) to produce a benchmark that allows us to make informed decisions and measure progress. We will make it public soon, but if you really can't wait, let us know at [founders@storia.ai](mailto:founders@storia.ai).

Here you will find our first learnings enabled by this dataset. We focused on proprietary APIs, but we plan to extend the experiments to open-source models as well.

#### TL;DR
- OpenAI's `text-embedding-3-small` embeddings perform best.
- NVIDIA's reranker outperforms Cohere, Voyage, and Jina.
- Sparse retrieval (e.g. BM25) actively hurts code retrieval if your index contains natural-language files (e.g. Markdown).
- Chunks of 800 tokens are ideal; going smaller yields very marginal gains.
- Going beyond `top_k=25` for retrieval has diminishing returns.

And now, if you want to nerd out, here's a bunch of plots and stats.

## Dataset
Our dataset consists of 1,000 `<question, answer, relevant_documents>` triples that focus on Hugging Face's [Transformers](https://github.com/huggingface/transformers) library.

The dataset was generated artificially and checked for quality by humans (we collaborated with [Morph Labs](https://morph.so)). The questions were designed to require context from 1-3 different Python files in order to be answered correctly.

A sample of 10 instances is provided in [sample.json](sample.json).

### Code Retrieval Benchmark
Here, we use the `<question, relevant_documents>` pairs as a code retrieval benchmark. For instance:
```
- Question:
When developing a specialized model class in the Transformers library, how does `auto_class_update` ensure that the new class's methods are tailored specifically for its requirements while preserving the functionality of the original methods from the base class?

- Relevant documents:
huggingface/transformers/src/transformers/models/auto/auto_factory.py
huggingface/transformers/src/transformers/utils/doc.py
```

#### Why not use an already-established code retrieval benchmark?
Indeed, there are already comprehensive code retrieval benchmarks like [CoIR](https://arxiv.org/abs/2407.02883). In fact, the [CosQA](https://arxiv.org/abs/2105.13239) subset of this benchmark has a format similar to ours (text-to-code retrieval for web queries).

However, we designed our document space to be *an entire codebase*, as opposed to a set of isolated Python functions. A real-world codebase contains a variety of files, including ones that are distracting and get undeservedly selected by the retriever. For instance, dense retrievers tend to prefer short files, and READMEs tend to score high even when irrelevant, since they're written in natural language. Our benchmark is able to surface such behaviors. It also allows us to experiment with a variety of strategies, such as file chunking.

In the rest of this document, we share a few initial learnings enabled by our benchmark.

### Metrics

Throughout this report, we use the following evaluation metrics, as implemented by the [ir-measures](https://ir-measur.es/en/latest/) library:
- [R-Precision](https://ir-measur.es/en/latest/measures.html#rprec): the precision at R, where R is the number of relevant documents for a given query. Since our queries have a variable number of relevant documents (1-3), this is a convenient metric.
- [Precision@1 (P@1)](https://ir-measur.es/en/latest/measures.html#p): reflects how often the document retrieved in first position is actually a golden document. Note that P@3 would be misleading: since not all queries have 3 relevant documents, not even the golden dataset would score 100%.
- [Recall@3 (R@3)](https://ir-measur.es/en/latest/measures.html#r): reflects how many of the golden documents were retrieved by the system. Note that R@1 would be misleading: since a query can have multiple equally relevant documents, not even the golden dataset would score 100%.
- [Mean Reciprocal Rank (MRR)](https://ir-measur.es/en/latest/measures.html#rr): for each query, takes the first golden document and looks up its rank among the retrieved documents. For instance, if the first golden document is retrieved second, the score for that query is 1/2. Note that this metric is somewhat incomplete for our benchmark, since a query can have multiple relevant documents.

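To make these definitions concrete, here is a dependency-free Python sketch on toy data (the actual numbers in this report come from the ir-measures implementations, not this code):

```python
def r_precision(golden, retrieved):
    """Precision among the top-R retrieved docs, where R = number of golden docs."""
    r = len(golden)
    return len(set(retrieved[:r]) & golden) / r

def precision_at_k(golden, retrieved, k):
    """Fraction of the top-k retrieved docs that are golden."""
    return len(set(retrieved[:k]) & golden) / k

def recall_at_k(golden, retrieved, k):
    """Fraction of the golden docs that appear in the top-k retrieved docs."""
    return len(set(retrieved[:k]) & golden) / len(golden)

def reciprocal_rank(golden, retrieved):
    """1/rank of the first golden doc among the retrieved ones (0 if absent)."""
    for rank, doc in enumerate(retrieved, start=1):
        if doc in golden:
            return 1.0 / rank
    return 0.0

# A query with two equally relevant golden files, like many in our dataset.
golden = {"src/auto_factory.py", "src/doc.py"}
retrieved = ["src/auto_factory.py", "README.md", "src/doc.py"]

print(r_precision(golden, retrieved))        # 0.5: only 1 of the top-2 is golden
print(precision_at_k(golden, retrieved, 1))  # 1.0: the first doc is golden
print(recall_at_k(golden, retrieved, 3))     # 1.0: both golden docs retrieved
print(reciprocal_rank(golden, retrieved))    # 1.0: first retrieved doc is golden
```

Note how R-Precision penalizes the irrelevant README even though both golden files were eventually retrieved.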

## Embeddings
:classical_building: **Verdict**: Use OpenAI's `text-embedding-3-small` embeddings.

Today, most retrieval systems are *dense*. They pre-compute document *embeddings* and store them in an index. At inference time, queries are also mapped to the same embedding space. In this world, retrieval is equivalent to finding the nearest neighbors of the query embedding in the index.

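Concretely, dense retrieval reduces to a nearest-neighbor search. A minimal NumPy sketch with toy 3-dimensional vectors (real embeddings such as `text-embedding-3-small` have 1,536 dimensions):

```python
import numpy as np

def top_k_nearest(query_emb, doc_embs, k):
    """Rank documents by cosine similarity to the query embedding."""
    q = query_emb / np.linalg.norm(query_emb)
    d = doc_embs / np.linalg.norm(doc_embs, axis=1, keepdims=True)
    sims = d @ q                  # cosine similarity of each document to the query
    return np.argsort(-sims)[:k]  # indices of the k most similar documents

# Toy embeddings for three documents and one query.
docs = np.array([[1.0, 0.0, 0.0],
                 [0.0, 1.0, 0.0],
                 [0.9, 0.1, 0.0]])
query = np.array([1.0, 0.05, 0.0])
print(top_k_nearest(query, docs, k=2))  # [0 2]: docs 0 and 2 point the same way as the query
```

In production, the index lookup is handled by a vector store (Pinecone, in our setup) rather than a brute-force scan.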

To this end, the [MTEB leaderboard](https://huggingface.co/spaces/mteb/leaderboard) (Massive Text Embeddings Benchmark) offers a comprehensive comparison of open-source embeddings.

To complement this, we compared proprietary embedding APIs from [OpenAI](https://platform.openai.com/docs/guides/embeddings) and [Voyage](https://docs.voyageai.com/docs/embeddings). The main advantage of using these providers (in addition to quality) is that they offer *batch* embedding APIs, so you can get an entire repository indexed relatively quickly without the headache of hosting your own embedding models (you can do so with a simple `sage-index $GITHUB_REPO` command).

![embeddings-plot](assets/embeddings.png)

The plot above shows the performance of the three embedding models from OpenAI (`text-embedding-3-small`, `text-embedding-3-large`, `text-embedding-ada-002`) and the code-specific embeddings from Voyage (`voyage-code-2`).

#### Experiment settings

- File chunks of <= 800 tokens;
- Dense retriever (nearest neighbors according to cosine distance of embeddings);
- Retrieved `top_k=25` documents;
- Reranked documents using the [NVIDIA re-ranker](https://docs.nvidia.com/nim/nemo-retriever/text-reranking/latest/using-reranking.html) and selected `top_k=3`.

#### Results

- Across most evaluation metrics, OpenAI's `text-embedding-3-small` performs best.
- It's remarkable that the `text-embedding-3-large` embeddings don't perform better, despite being double the size (3072 vs. 1536 dimensions).
- The older `text-embedding-ada-002` embeddings trail last with a huge performance gap, so this is your call to update your pipeline if you haven't already.


## Rerankers
:classical_building: **Verdict**: Use NVIDIA's reranker.

In a world with infinitely fast compute, we would perform retrieval by passing each `<query, document>` pair through a Transformer, allowing all the query tokens to attend to all the document tokens. However, this is prohibitively expensive.

In practice, all documents are embedded independently and stored in a vector database. Most retrieval systems are two-staged: (1) embed the query independently and find its top N nearest-neighbor documents, and (2) re-encode all top N `<query, document>` pairs and select the top K scoring ones. The second stage is called *reranking*.

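The two-stage pipeline can be sketched as follows. The toy scoring functions below are hypothetical stand-ins for the embedding model and the reranker (in our experiments, OpenAI embeddings and the NVIDIA reranker):

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def two_stage_retrieve(query, docs, embed, rerank_score, n=25, k=3):
    """Stage 1: cheap nearest-neighbor search over independent embeddings.
    Stage 2: expensive joint scoring (reranking) of the surviving candidates."""
    q = embed(query)
    candidates = sorted(docs, key=lambda d: -dot(q, embed(d)))[:n]
    return sorted(candidates, key=lambda d: -rerank_score(query, d))[:k]

# Toy stand-ins: embed by counts over a tiny vocabulary, rerank by word overlap.
def embed(text):
    vocab = ["pipeline", "model", "readme", "install"]
    words = text.lower().split()
    return [words.count(w) for w in vocab]

def rerank_score(query, doc):
    return len(set(query.lower().split()) & set(doc.lower().split()))

docs = ["pipeline model code", "readme install notes", "pipeline utils"]
print(two_stage_retrieve("pipeline model", docs, embed, rerank_score, n=2, k=1))
# ['pipeline model code']
```

The key property: the reranker sees the query and document *together*, so it can catch relevance signals that independent embeddings miss.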

![rerankers-plot](assets/rerankers.png)

While the [MTEB leaderboard](https://huggingface.co/spaces/mteb/leaderboard) compares *open-source* embedding models based on their ability to rerank documents, we conducted experiments on the most popular *proprietary* reranking APIs: [NVIDIA](https://docs.nvidia.com/nim/nemo-retriever/text-reranking/latest/using-reranking.html), [Voyage](https://docs.voyageai.com/docs/reranker), [Cohere](https://cohere.com/rerank), and [Jina](https://jina.ai/reranker/).

#### Experiment settings
- File chunks of <= 800 tokens;
- Dense retriever using OpenAI's `text-embedding-3-small` model;
- Retrieved `top_k=25` documents;
- Reranked documents and selected `top_k=3`.

#### Results
- Across all evaluation metrics, the highest-performing rerankers are, in this order: [NVIDIA](https://docs.nvidia.com/nim/nemo-retriever/text-reranking/latest/using-reranking.html), [Voyage](https://docs.voyageai.com/docs/reranker), [Cohere](https://cohere.com/rerank), and [Jina](https://jina.ai/reranker/).
- Not using a reranker at all completely tanks performance.


## Retrieval: Sparse vs. Dense
:classical_building: **Verdict**: Use fully dense retrieval.

So far, we've been experimenting with purely *dense* retrieval. That is, documents are selected solely based on the cosine distance between their embedding and the query embedding.

Before the emergence of deep learning, retrievers used to be *sparse*. Such retrievers (e.g. [TF-IDF](https://en.wikipedia.org/wiki/Tf%E2%80%93idf) or [BM25](https://en.wikipedia.org/wiki/Okapi_BM25)) are based on vectors of word counts: the vector of a document has the length of the vocabulary, with each entry showing how many times a token occurs in the document. The term *sparse* comes from the fact that most entries are 0.

Since sparse retrievers rely on exact string matching, one might assume they come in handy when the query contains a relatively unique token (e.g. a class name) that occurs in a small number of documents.

At the intersection of dense and sparse retrievers, *hybrid* retrievers score documents by a weighted average of the dense and sparse scores.

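In other words, a hybrid retriever computes a convex combination of the two scores. This is the kind of knob exposed by this repo's `--retrieval-alpha` flag (the exact flag semantics are the implementation's; the sketch below assumes `alpha = 1.0` means fully dense and that both scores were normalized to a comparable range beforehand):

```python
def hybrid_score(dense_score, sparse_score, alpha):
    """alpha = 1.0 -> purely dense; alpha = 0.0 -> purely sparse (e.g. BM25)."""
    return alpha * dense_score + (1 - alpha) * sparse_score

# A document that the embeddings like (0.8) but BM25 does not (0.2):
print(hybrid_score(0.8, 0.2, alpha=1.0))  # 0.8 (dense only)
print(hybrid_score(0.8, 0.2, alpha=0.5))  # 0.5 (even blend)
print(hybrid_score(0.8, 0.2, alpha=0.0))  # 0.2 (sparse only)
```

As the results below show, on our benchmark the best setting for this blend turned out to be the fully dense end of the dial.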

![retrievers-plot](assets/retrievers.png)

In the experiment above, we compared the three types of retrievers (dense, hybrid, and sparse).

#### Experiment settings
- File chunks of <= 800 tokens;
- For the dense and hybrid retrievers, we used OpenAI's `text-embedding-3-small` model for embeddings;
- Retrieved `top_k=25` documents;
- Reranked documents using the [NVIDIA re-ranker](https://docs.nvidia.com/nim/nemo-retriever/text-reranking/latest/using-reranking.html) and selected `top_k=3`.

#### Results
Somewhat surprisingly, sparse retrieval actively hurts performance. The reason is that exact string matching favors files written in natural language, which match the token distribution of the query.

The plot below shows what percentage of the retrieved files are in Markdown. The purely sparse retriever chooses a Markdown file 40% of the time! Remember that we designed our questions so that the required context consists of Python files. This doesn't preclude Markdown files from being genuinely helpful for some questions, but surely not to this degree.

![markdown-plot](assets/markdown.png)


## Chunk sizes
:classical_building: **Verdict**: 800 tokens per chunk works well.

The [CodeRag paper](https://arxiv.org/pdf/2406.14497) suggests that the ideal chunk size is somewhere between 200 and 800 tokens. All our experiments above used 800 tokens per chunk. When experimenting with the other end of the spectrum, we saw only very mild improvements from smaller chunks. We believe these marginal gains are not worth the increased indexing time (since we would need to send 4x more requests to the batch embedding APIs).

![chunks-plot](assets/chunks.png)
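For reference, fixed-size chunking can be sketched as a naive token splitter. Real chunkers typically try to respect syntactic boundaries (functions, classes), so treat this as an illustration of the size trade-off only:

```python
def chunk_tokens(tokens, max_tokens=800):
    """Split a file's token sequence into consecutive chunks of at most max_tokens."""
    return [tokens[i:i + max_tokens] for i in range(0, len(tokens), max_tokens)]

# A 2,000-token file yields chunks of 800, 800, and 400 tokens.
tokens = list(range(2000))
print([len(c) for c in chunk_tokens(tokens)])  # [800, 800, 400]
```

Halving `max_tokens` doubles the number of chunks, which is exactly the indexing-time cost weighed against the marginal quality gains above.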
benchmarks/retrieval/assets/chunks.png ADDED
benchmarks/retrieval/assets/embeddings.png ADDED
benchmarks/retrieval/assets/markdown.png ADDED
benchmarks/retrieval/assets/rerankers.png ADDED
benchmarks/retrieval/assets/retrievers.png ADDED
benchmarks/retrieval/retrieve.py ADDED
@@ -0,0 +1,108 @@
"""Script to run retrieval on a benchmark dataset.

Make sure to `pip install ir_measures` before running this script.
"""

import json
import logging
import os
import time

import configargparse
from ir_measures import MAP, MRR, P, Qrel, R, Rprec, ScoredDoc, calc_aggregate, nDCG

import sage.config
from sage.retriever import build_retriever_from_args

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()
logger.setLevel(logging.INFO)


def main():
    parser = configargparse.ArgParser(
        description="Runs retrieval on a benchmark dataset.", ignore_unknown_config_file_keys=True
    )
    parser.add("--benchmark", required=True, help="Path to the benchmark dataset.")
    parser.add(
        "--gold-field", default="context_files", help="Field in the benchmark dataset that contains the golden file paths."
    )
    parser.add(
        "--question-field", default="question", help="Field in the benchmark dataset that contains the questions."
    )
    parser.add(
        "--logs-dir",
        default=None,
        help="Path where to output predictions and metrics. Optional, since metrics are also printed to the console.",
    )
    parser.add("--max-instances", default=None, type=int, help="Maximum number of instances to process.")

    sage.config.add_config_args(parser)
    sage.config.add_embedding_args(parser)
    sage.config.add_vector_store_args(parser)
    sage.config.add_reranking_args(parser)
    args = parser.parse_args()
    sage.config.validate_vector_store_args(args)

    retriever = build_retriever_from_args(args)

    with open(args.benchmark, "r") as f:
        benchmark = json.load(f)
    if args.max_instances is not None:
        benchmark = benchmark[: args.max_instances]

    golden_docs = []  # List of ir_measures.Qrel objects.
    retrieved_docs = []  # List of ir_measures.ScoredDoc objects.

    for question_idx, item in enumerate(benchmark):
        print(f"Processing question {question_idx}...")

        query_id = str(question_idx)  # Solely needed for the ir_measures library.

        for golden_filepath in item[args.gold_field]:
            # All the file paths in the golden answer are equally relevant for the query (i.e. the order is
            # irrelevant), so we set relevance=1 for all of them.
            golden_docs.append(Qrel(query_id=query_id, doc_id=golden_filepath, relevance=1))

        # Make a retrieval call for the current question.
        retrieved = retriever.invoke(item[args.question_field])
        item["retrieved"] = []
        for doc_idx, doc in enumerate(retrieved):
            # The absolute value of the scores below does not affect the metrics; it merely determines the ranking
            # of the retrieved documents. The key of the score varies depending on the underlying retriever. If
            # there's no score, we use 1/(doc_idx+1), since it preserves the order of the documents.
            score = doc.metadata.get("score", doc.metadata.get("relevance_score", 1 / (doc_idx + 1)))
            retrieved_docs.append(ScoredDoc(query_id=query_id, doc_id=doc.metadata["file_path"], score=score))
            # Update the output dictionary with the retrieved documents.
            item["retrieved"].append({"file_path": doc.metadata["file_path"], "score": score})

        if "answer" in item:
            item.pop("answer")  # Long answers make the output file harder to read.

    print("Calculating metrics...")
    results = calc_aggregate([Rprec, P @ 1, R @ 3, nDCG @ 3, MAP, MRR], golden_docs, retrieved_docs)
    results = {str(key): value for key, value in results.items()}

    if args.logs_dir:
        if not os.path.exists(args.logs_dir):
            os.makedirs(args.logs_dir)

        out_data = {
            "data": benchmark,
            "metrics": results,
            "flags": vars(args),  # For reproducibility.
        }

        output_file = os.path.join(args.logs_dir, f"{time.time()}.json")
        with open(output_file, "w") as f:
            json.dump(out_data, f, indent=4)

    for key in sorted(results.keys()):
        print(f"{key}: {results[key]}")
    if args.logs_dir:
        # Only mention the output file when one was actually written.
        print(f"Predictions and metrics saved to {output_file}")


if __name__ == "__main__":
    main()
benchmarks/retrieval/sample.json ADDED
@@ -0,0 +1,177 @@
+ [
+ {
+ "repo": "huggingface/transformers",
+ "commit": "7bb1c99800d235791dace10305731f377db8077b",
+ "context_files": [
+ "huggingface/transformers/src/transformers/commands/serving.py",
+ "huggingface/transformers/src/transformers/pipelines/__init__.py"
+ ],
+ "question": "With the introduction of a new translation service for \"en_to_es\", how does `serve_command_factory` ensure the server is prepared to handle this specific task efficiently?",
+ "answer": "The `serve_command_factory` function is designed to dynamically configure and deploy a server that can handle a variety of tasks, including complex and parameterized tasks such as language-specific translations. When a new translation service for \"en_to_es\" is introduced, the function ensures efficient handling through several key steps and mechanisms:\n\n1. **Task Specification and Validation**: The function begins by accepting command line arguments (`args`) which specify the task among other parameters. The task, in this case \"translation_en_to_es\", is passed to the `pipeline` function. This function internally calls `check_task` to validate and process the task string. The `check_task` function checks if the task is recognized and supported, and it returns structured data that includes default configurations and any task-specific options necessary for the task.\n\n2. **Dynamic Pipeline Configuration**: The `pipeline` function uses the information returned by `check_task` to dynamically configure a pipeline suitable for the specified task. This includes loading the appropriate model, tokenizer, and any additional components like feature extractors or image processors that are required for the task. For translation tasks, this might involve loading a specific model that is optimized for the language pair \"en_to_es\".\n\n3. **Model and Tokenizer Loading**: Depending on the task and the provided arguments, the `pipeline` function loads the necessary model and tokenizer. If the model or tokenizer is not explicitly specified in the command line arguments, the function will load default models and tokenizers that are appropriate for the task. This ensures that the server is using the most suitable tools for the translation between English and Spanish.\n\n4. 
**Device and Performance Configuration**: The `pipeline` function also handles device allocation (e.g., CPU, GPU) based on the provided arguments, ensuring that the model runs on the optimal hardware. This is crucial for performance, especially when handling complex tasks like real-time translation.\n\n5. **Server Initialization**: Once the pipeline is configured, `serve_command_factory` initializes a `ServeCommand` with the configured pipeline and server parameters such as host, port, and number of workers. This command encapsulates all the necessary settings to run the server efficiently.\n\n6. **Adaptability and Scalability**: The design of `serve_command_factory` and the underlying `pipeline` function allows for easy adaptation to new tasks and changes in task requirements. Adding support for a new translation like \"en_to_es\" primarily involves ensuring that the `check_task` function recognizes the task and that appropriate models and tokenizers are available and correctly configured.\n\nIn summary, `serve_command_factory` ensures that the server is prepared to handle the \"en_to_es\" translation task efficiently by dynamically configuring the server based on the specified task, validating and processing the task to load appropriate models and tools, and setting up the server with optimal hardware and network settings. This process guarantees that the server can adapt to new and complex tasks with minimal manual intervention.",
+ "expected_identifiers": [
+ "args",
+ "check_task",
+ "pipeline",
+ "ServeCommand"
+ ]
+ },
+ {
+ "repo": "huggingface/transformers",
+ "commit": "7bb1c99800d235791dace10305731f377db8077b",
+ "context_files": [
+ "huggingface/transformers/src/transformers/models/esm/modeling_esmfold.py",
+ "huggingface/transformers/src/transformers/models/esm/openfold_utils/feats.py",
+ "huggingface/transformers/src/transformers/models/esm/openfold_utils/tensor_utils.py"
+ ],
+ "question": "In a high-throughput setting where multiple protein structures are processed simultaneously, how does `EsmForProteinFolding.output_to_pdb` ensure accurate and independent structural representation in the resulting PDB files?",
+ "answer": "In a high-throughput setting where multiple protein structures are processed simultaneously, the function `output_to_pdb` ensures accurate and independent structural representation in the resulting PDB files through a combination of specialized tensor operations and careful indexing. This is achieved primarily through the use of the `atom14_to_atom37` function, which itself relies on the `batched_gather` function to correctly map atom positions from a simplified model output to a more detailed atomic representation.\n\n### Detailed Workflow:\n\n1. **Batch Processing and Tensor Operations**:\n - The `output_to_pdb` function begins by converting all tensor data to the CPU and converting them to NumPy arrays for easier manipulation. This step is crucial for performance and compatibility with subsequent operations that may not be optimized for GPU tensors.\n\n2. **Mapping Atom Positions**:\n - The function `atom14_to_atom37` is called within `output_to_pdb`. This function is responsible for expanding the reduced atom representation (14 atoms per amino acid) to a fuller representation (37 atoms per amino acid). It uses the `batched_gather` function to achieve this mapping accurately across potentially multiple proteins in a batch.\n\n3. **Complex Indexing with `batched_gather`**:\n - `batched_gather` plays a critical role in ensuring that the atom positions are mapped correctly. It constructs a complex indexing tuple that combines batch indices with the provided indices for gathering (`inds`). This tuple (`ranges`) includes both batch dimensions and the specific indices where atoms need to be gathered from the `atom14` tensor.\n - The use of `ranges` in `batched_gather` ensures that each protein's data is handled independently, preventing any cross-contamination or mixing of data between different proteins in the batch. This is crucial for maintaining the structural integrity of each protein.\n\n4. 
**Application of Mask and Final Adjustments**:\n - After mapping the positions, `atom14_to_atom37` applies a mask (`batch[\"atom37_atom_exists\"]`) to ensure that only existing atoms are considered. This step further ensures the accuracy of the structural data by zeroing out positions of non-existent atoms, preventing any erroneous data from affecting the structural representation.\n\n5. **Generation of PDB Data**:\n - Back in `output_to_pdb`, for each protein in the batch, an instance of `OFProtein` is created with the mapped atom positions, types, and other relevant data. The `to_pdb` function is then used to convert these protein data into the PDB format, ready for downstream applications like molecular dynamics simulations.\n\n### Conclusion:\n\nThrough the careful use of tensor operations, complex indexing, and data masking, `output_to_pdb` ensures that each protein's structural data is accurately and independently represented in the PDB outputs. This methodical approach is essential in high-throughput settings, where the accuracy and integrity of structural data are paramount for subsequent scientific analysis and applications.",
+ "expected_identifiers": [
+ "atom14_to_atom37",
+ "batched_gather",
+ "batch[\"atom37_atom_exists\"]",
+ "OFProtein"
+ ]
+ },
+ {
+ "repo": "huggingface/transformers",
+ "commit": "7bb1c99800d235791dace10305731f377db8077b",
+ "context_files": [
+ "huggingface/transformers/src/transformers/models/auto/auto_factory.py",
+ "huggingface/transformers/src/transformers/dynamic_module_utils.py"
+ ],
+ "question": "Following a security update in the production environment that limits internet connectivity, how does `_BaseAutoModelClass.from_pretrained` guarantee that the loaded model adheres strictly to the predefined version and settings?",
+ "answer": "In the updated production environment with restricted internet connectivity, `_BaseAutoModelClass.from_pretrained` ensures that the model loaded adheres strictly to the predefined version and settings through several key mechanisms, primarily involving the management of model files and code via a version control system and secure access to private repositories.\n\n### Version Control and Revision Specification\n\nThe function leverages a version control system that allows users to specify exact revisions of the model or code they wish to use. This is evident in the handling of the `revision` parameter in functions like `get_cached_module_file` and `get_class_from_dynamic_module`. The `revision` parameter can accept any identifier allowed by git, such as a branch name, a tag name, or a commit id. This ensures that the exact version of the model or code that was tested and approved in other environments (like development or staging) is the same version being deployed in production.\n\nFor example, in the `get_cached_module_file` function, the `revision` parameter is used to fetch the specific version of a module file from a repository:\n```python\nresolved_module_file = cached_file(\n pretrained_model_name_or_path,\n module_file,\n cache_dir=cache_dir,\n force_download=force_download,\n proxies=proxies,\n resume_download=resume_download,\n local_files_only=local_files_only,\n token=token,\n revision=revision,\n repo_type=repo_type,\n _commit_hash=_commit_hash,\n)\n```\n\n### Secure Access to Private Repositories\n\nThe function can authenticate access to private repositories using tokens, which is crucial when operating in environments with strict security protocols. The `token` parameter, which can be set to a string or `True` (to use the token generated by `huggingface-cli login`), is used to authenticate HTTP requests for remote files. 
This is handled securely in both `get_cached_module_file` and `get_class_from_dynamic_module`, ensuring that only authorized users can access private model files or code.\n\nFor instance, in `get_class_from_dynamic_module`, the `token` parameter is used to authenticate and download the necessary module file:\n```python\nfinal_module = get_cached_module_file(\n repo_id,\n module_file + \".py\",\n cache_dir=cache_dir,\n force_download=force_download,\n resume_download=resume_download,\n proxies=proxies,\n token=token,\n revision=code_revision,\n local_files_only=local_files_only,\n repo_type=repo_type,\n)\n```\n\n### Handling Restricted Internet Connectivity\n\nIn environments with limited internet access, the `local_files_only` parameter becomes particularly important. This parameter, when set to `True`, forces the function to only look for model files locally and not attempt to download them from the internet. This is crucial for ensuring that the model loading process does not fail due to lack of internet access and adheres to strict security protocols that might block external internet connections.\n\n### Conclusion\n\nBy utilizing these mechanisms, `_BaseAutoModelClass.from_pretrained` ensures that the model loaded in a production environment with restricted internet access is exactly the version specified, using secure and authenticated access where necessary. This approach guarantees consistency, reproducibility, and adherence to security protocols across different environments.",
44
+ "expected_identifiers": [
45
+ "revision",
46
+ "token",
47
+ "local_files_only"
48
+ ]
49
+ },
50
+ {
51
+ "repo": "huggingface/transformers",
52
+ "commit": "7bb1c99800d235791dace10305731f377db8077b",
53
+ "context_files": [
54
+ "huggingface/transformers/src/transformers/models/auto/auto_factory.py",
55
+ "huggingface/transformers/src/transformers/utils/doc.py"
56
+ ],
57
+ "question": "When developing a specialized model class in the Transformers library, how does `auto_class_update` ensure that the new class's methods are tailored specifically for its requirements while preserving the functionality of the original methods from the base class?",
58
+ "answer": "In the Transformers library, the `auto_class_update` function plays a crucial role in dynamically creating specialized model classes that inherit functionalities from a base class but also have unique customizations. This is particularly important when different model classes need specific configurations or preprocessing steps that are not shared across all models.\n\nThe core mechanism that allows `auto_class_update` to achieve this functionality without altering the behavior of the base class methods lies in its use of the `copy_func` function. Here's how it works step-by-step:\n\n1. **Copying the Function**: `copy_func` is used to create an exact copy of the methods `from_config` and `from_pretrained` from the base class `_BaseAutoModelClass`. This is done by duplicating the `__code__` object of these methods. The `__code__` object contains the compiled executable code that the Python interpreter runs. By copying this code object, the new function retains the exact behavior and logic of the original function.\n\n2. **Customization of the Copied Function**: After copying, `auto_class_update` modifies the docstrings of these methods to tailor them to the specific subclass. This involves inserting a specific `head_doc`, replacing placeholders like `\"BaseAutoModelClass\"` with the subclass's name, and updating example checkpoints specific to the model type (e.g., `\"google-bert/bert-base-cased\"`). These modifications are crucial for providing accurate and relevant documentation and guidance specific to each subclass.\n\n3. **Re-assignment as Class Methods**: Once the functions are copied and customized, they are re-assigned to the subclass as class methods. This is done using `classmethod(from_config)` and `classmethod(from_pretrained)`. This step ensures that these methods, now tailored and documented specifically for the subclass, are callable on the subclass itself.\n\n4. 
**Preservation of Base Class Functionality**: Since the original methods are copied before being modified, the base class `_BaseAutoModelClass` retains its original `from_config` and `from_pretrained` methods without any changes. This isolation ensures that modifications specific to one subclass do not impact the behavior or documentation of these methods in the base class or any other subclasses.\n\nBy following this process, `auto_class_update` ensures that each subclass in the Transformers library can have methods that are specifically tailored to its requirements, both in terms of functionality and documentation, while preserving the integrity and functionality of the original methods from the base class. This approach enhances modularity and flexibility in the library, allowing developers to easily extend and customize model classes for various use cases.",
59
+ "expected_identifiers": [
60
+ "__code__",
61
+ "copy_func",
62
+ "from_config",
63
+ "from_pretrained"
64
+ ]
65
+ },
66
+ {
67
+ "repo": "huggingface/transformers",
68
+ "commit": "7bb1c99800d235791dace10305731f377db8077b",
69
+ "context_files": [
70
+ "huggingface/transformers/src/transformers/models/megatron_gpt2/checkpoint_reshaping_and_interoperability.py",
71
+ "huggingface/transformers/src/transformers/modeling_utils.py"
72
+ ],
73
+ "question": "Given a system limitation of 5GB per file, how does `convert_checkpoint_from_megatron_to_transformers` manage the storage of a large model's data to comply with this restriction?",
74
+ "answer": "The `convert_checkpoint_from_megatron_to_transformers` function manages the storage of a large model's data to comply with a system limitation of 5GB per file by utilizing the `shard_checkpoint` function to split the model's state dictionary into multiple sub-checkpoints, each of which does not exceed the specified maximum size.\n\nHere's a detailed breakdown of how this is achieved:\n\n1. **Sharding Process**: The `shard_checkpoint` function is called within `convert_checkpoint_from_megatron_to_transformers` to handle the division of the model's weights into smaller parts or shards. This function takes the entire state dictionary of the model (`output_state_dict`) and a maximum shard size as inputs.\n\n2. **Size Calculation**: The function calculates the byte size of each tensor in the state dictionary using the `dtype_byte_size` function. This function determines the number of bytes each element of a tensor occupies in memory, based on the tensor's data type (`dtype`). This calculation is crucial as it helps in accurately assessing how much space each tensor will take when saved as part of a shard.\n\n3. **Iterative Sharding**: The `shard_checkpoint` iterates through each tensor in the state dictionary and adds them to the current shard until adding another tensor would exceed the maximum shard size (5GB in this scenario). When this limit is reached, a new shard is started. This ensures that no individual shard file exceeds the specified size limit.\n\n4. **Handling Oversized Tensors**: If a single tensor is larger than the maximum shard size, it is placed in its own shard. This is a necessary exception to prevent the function from failing due to an inability to split a tensor.\n\n5. **Saving Shards**: Each shard is saved as a separate file. The naming convention and indexing ensure that each part of the model can be identified and accessed correctly. 
The function also generates an index file if the model is split into multiple shards, detailing where each parameter is stored.\n\n6. **Parameter Mapping**: The function maintains a mapping (`weight_map`) of model parameters to their respective shard files. This mapping is crucial for efficiently loading the model from its sharded state.\n\nBy following these steps, the `convert_checkpoint_from_megatron_to_transformers` function ensures that each shard of the converted model adheres to the 5GB file size limit imposed by the system. This methodical sharding allows for efficient storage and handling of large models without exceeding system file size limitations.",
75
+ "expected_identifiers": [
76
+ "shard_checkpoint",
77
+ "dtype_byte_size",
78
+ "output_state_dict",
79
+ "weight_map"
80
+ ]
81
+ },
82
+ {
83
+ "repo": "huggingface/transformers",
84
+ "commit": "7bb1c99800d235791dace10305731f377db8077b",
85
+ "context_files": [
86
+ "huggingface/transformers/src/transformers/quantizers/quantizer_hqq.py",
87
+ "huggingface/transformers/src/transformers/integrations/hqq.py"
88
+ ],
89
+ "question": "In a scenario where a neural network model is being optimized for deployment, how does `HqqHfQuantizer._process_model_before_weight_loading` ensure that each linear module is appropriately and uniquely quantized?",
90
+ "answer": "In the scenario where a neural network model is being optimized for deployment using the `HqqHfQuantizer._process_model_before_weight_loading` function, the process of ensuring that each linear module is appropriately and uniquely quantized involves several key steps and functions.\n\n1. **Tagging Modules with Unique Identifiers**: The process begins with the `get_linear_tags` function, which is responsible for identifying and tagging all linear modules within the model. This function uses a `set` to collect the names of these modules, which inherently ensures that each tag is unique (since sets do not allow duplicates). This is crucial because it prevents any confusion or errors in later stages when quantization parameters are applied to these tags.\n\n2. **Applying Quantization Configuration**: Once the linear modules are tagged, the `prepare_for_hqq_linear` function takes over. This function receives a `quantization_config` and a list of modules not to convert. It first calls `autoname_modules` to ensure each module in the model has a unique name, and then retrieves the linear tags using `get_linear_tags`. The function then filters these tags to exclude any specified in `skip_modules` or `modules_to_not_convert`, ensuring that the quantization process is applied only to the relevant modules.\n\n3. **Mapping Quantization Parameters**: The core of the quantization process happens when `prepare_for_hqq_linear` maps the quantization parameters to each linear tag. This is done by creating a dictionary (`patch_params`) where each key is a linear tag and the value is the corresponding quantization parameter. If specific quantization parameters are not provided for a tag, a default configuration is applied. This mapping ensures that each linear module (identified uniquely by its tag) receives a tailored set of quantization parameters.\n\n4. 
**Updating Model Configuration**: After mapping the quantization parameters, the `prepare_for_hqq_linear` function updates the model's configuration to include these parameters, ensuring that each linear module's configuration reflects its unique quantization settings. This step is crucial for the actual quantization process, where linear modules might be replaced with their quantized counterparts (`HQQLinear`), depending on the configuration.\n\n5. **Final Verification and Logging**: The function checks if any linear modules have been replaced and logs a warning if no modules were found for quantization. This serves as a final check to ensure that the quantization process has been applied as expected.\n\nIn summary, the `HqqHfQuantizer._process_model_before_weight_loading` function ensures that each linear module is uniquely and appropriately quantized by meticulously tagging each module, applying a tailored quantization configuration, and updating the model to reflect these settings. This process is designed to optimize the model's performance for deployment, ensuring that each module operates efficiently and accurately under the constraints of quantization.",
91
+ "expected_identifiers": [
92
+ "get_linear_tags",
93
+ "autoname_modules",
94
+ "prepare_for_hqq_linear",
95
+ "patch_params"
96
+ ]
97
+ },
98
+ {
99
+ "repo": "huggingface/transformers",
100
+ "commit": "7bb1c99800d235791dace10305731f377db8077b",
101
+ "context_files": [
102
+ "huggingface/transformers/src/transformers/models/esm/modeling_esmfold.py",
103
+ "huggingface/transformers/src/transformers/models/esm/openfold_utils/loss.py"
104
+ ],
105
+ "question": "When analyzing a protein sequence with low complexity using `EsmForProteinFolding.forward`, how is the stability and definition of the output ensured?",
106
+ "answer": "When analyzing a protein sequence with low complexity using the `EsmForProteinFolding.forward` function, the stability and definition of the output are ensured through several key mechanisms embedded within the function's implementation, particularly in how it handles normalization and potential numerical instabilities.\n\n1. **Normalization of Residue Weights**: In the `compute_tm` function, residue weights are normalized by their sum, with the addition of a small constant `eps` (epsilon) to prevent division by zero. This is crucial when dealing with sequences of low complexity where certain residues might be overrepresented or underrepresented. The normalization step is represented in the code as:\n ```python\n normed_residue_mask = residue_weights / (eps + residue_weights.sum())\n ```\n Here, `eps` acts as a safeguard against division by zero, ensuring that the function remains numerically stable and produces defined outputs even when the sum of residue weights is extremely small or zero.\n\n2. **Weighted Average Calculation**: The function calculates a weighted average of the Template Modeling (TM) scores across different bins, which is critical for obtaining a reliable TM score. This is done using the normalized residue weights, ensuring that each residue's contribution is proportionate to its presence, thus maintaining accuracy and stability in the final score calculation:\n ```python\n per_alignment = torch.sum(predicted_tm_term * normed_residue_mask, dim=-1)\n ```\n This step aggregates the TM scores across all residues, factoring in their normalized weights, which is particularly important in low complexity sequences where certain residues might dominate.\n\n3. **Handling of Edge Cases**: The use of `eps` in the normalization process is a direct method to handle edge cases, such as sequences with low complexity or unusual amino acid distributions. 
By ensuring that the denominator in the normalization step is never zero, the function avoids potential runtime errors (like NaN or infinite values), which could disrupt the analysis process.\n\n4. **Integration within `EsmForProteinFolding.forward`**: The stability and definition of outputs from the `EsmForProteinFolding.forward` function are further supported by how `compute_tm` integrates with other components of the model. The TM scores computed are used alongside other structural predictions, contributing to a comprehensive evaluation of the predicted protein structures. This integration ensures that the outputs are not only stable and defined but also meaningful in the context of protein structure prediction.\n\nIn summary, the `EsmForProteinFolding.forward` function ensures stable and defined outputs for protein structure predictions, particularly in scenarios involving low complexity sequences, by employing robust normalization techniques and handling potential numerical instabilities through the careful addition of a small epsilon value in critical calculations. This approach guarantees that the function can reliably process a wide range of input data without encountering computational errors.",
107
+ "expected_identifiers": [
108
+ "normed_residue_mask",
109
+ "eps",
110
+ "residue_weights / (eps + residue_weights.sum())",
111
+ "torch.sum(predicted_tm_term * normed_residue_mask, dim=-1)"
112
+ ]
113
+ },
114
+ {
115
+ "repo": "huggingface/transformers",
116
+ "commit": "7bb1c99800d235791dace10305731f377db8077b",
117
+ "context_files": [
118
+ "huggingface/transformers/src/transformers/pipelines/question_answering.py",
119
+ "huggingface/transformers/src/transformers/data/processors/squad.py"
120
+ ],
121
+ "question": "In a scenario where the textual data includes unusually lengthy paragraphs, how does `QuestionAnsweringPipeline.preprocess` ensure comprehensive coverage of all context tokens in the model's input sequences?",
122
+ "answer": "In scenarios where the textual data includes unusually lengthy paragraphs that exceed the model's maximum input length, the `QuestionAnsweringPipeline.preprocess` function ensures comprehensive coverage of all context tokens in the model's input sequences through a meticulous management of tokenization and handling of overflow tokens. This process is crucial for maintaining the integrity and continuity of the context information, which is essential for the model to accurately answer questions based on the provided context.\n\n### Step-by-Step Explanation:\n\n1. **Tokenization and Pairing**:\n The function begins by tokenizing the question and context separately. Depending on the tokenizer's configuration (`tokenizer.padding_side`), the question and context are arranged in a specific order (either question first or context first). This is handled in the lines where `encoded_inputs` is defined using `self.tokenizer(text, text_pair, ...)`. \n\n2. **Handling Long Contexts with Overflow Tokens**:\n The key parameter here is `return_overflowing_tokens=True` within the tokenizer call. This setting ensures that when the combined length of the question and context exceeds `max_seq_len`, the tokenizer automatically generates additional input sequences that contain the \"overflow\" tokens from the context. These sequences overlap by a number of tokens defined by `doc_stride`, which is calculated as `min(max_seq_len // 2, 128)`.\n\n3. **Creating Overlapping Spans**:\n The overlapping spans are crucial for ensuring that tokens near the boundaries of a sequence are also seen in different contextual surroundings, enhancing the model's ability to understand and answer questions about tokens that appear near the maximum sequence length limit. This overlap is managed by the `stride` parameter in the tokenizer, which is set to `doc_stride`.\n\n4. 
**Feature Construction**:\n For each span generated from the overflowing tokens, the function constructs a feature object that includes not only the token ids (`input_ids`) but also attention masks, token type ids, and a special mask (`p_mask`) which indicates which tokens can be part of an answer. The `p_mask` is particularly important as it helps the model distinguish between context tokens (potential answer locations) and non-context tokens (like those belonging to the question or special tokens).\n\n5. **Yielding Processed Features**:\n Each feature constructed from the spans is then yielded one by one, with additional metadata such as whether it is the last feature of the example. This is handled in the loop `for i, feature in enumerate(features):` where each feature is prepared according to the model's requirements, potentially converting them into tensors suitable for the model's computation framework (PyTorch or TensorFlow).\n\n### Conclusion:\n\nBy managing the tokenization and overflow tokens effectively, `QuestionAnsweringPipeline.preprocess` ensures that every token from a lengthy context is included in at least one input sequence to the model. This comprehensive coverage is achieved through the creation of multiple, overlapping input sequences that ensure no contextual information is lost, thereby enabling the model to perform accurately even with contexts that far exceed its maximum input length. This methodical approach is crucial for handling real-world data where lengthy paragraphs are common, ensuring robust model performance across varied datasets.",
123
+ "expected_identifiers": [
124
+ "return_overflowing_tokens=True",
125
+ "doc_stride",
126
+ "p_mask"
127
+ ]
128
+ },
129
+ {
130
+ "repo": "huggingface/transformers",
131
+ "commit": "7bb1c99800d235791dace10305731f377db8077b",
132
+ "context_files": [
133
+ "huggingface/transformers/examples/research_projects/movement-pruning/masked_run_squad.py",
134
+ "huggingface/transformers/src/transformers/data/processors/squad.py"
135
+ ],
136
+ "question": "Given the challenge of training models on datasets with varying context lengths, how does `load_and_cache_examples` in `examples/research_projects/movement-pruning/masked_run_squad.py` ensure the production of a dataset that supports accurate answer position prediction by the models?",
137
+ "answer": "The `load_and_cache_examples` function in `examples/research_projects/movement-pruning/masked_run_squad.py` is designed to ensure that the dataset produced supports accurate answer position prediction by models, even when dealing with varying context lengths. This is achieved through a series of steps that involve careful handling and processing of the data, particularly when contexts exceed the model's maximum sequence length. Here's how the function manages this:\n\n1. **Data Loading and Caching**: Initially, the function checks if a cached version of the processed data exists. If it does, it loads the features, dataset, and examples directly from the cache, bypassing the need for reprocessing. This step ensures efficiency and consistency in data usage across different training runs.\n\n2. **Dynamic Data Processing**: If no cached data is available, the function processes the raw data to generate features suitable for model training. This involves tokenizing the text and converting the SQuAD examples into features that models can understand and learn from.\n\n3. **Handling Extended Contexts**: The core of handling varying context lengths lies in the `squad_convert_examples_to_features` function, which is called within `load_and_cache_examples`. This function uses `squad_convert_example_to_features` to process each example individually.\n\n4. **Segmentation and Token Index Adjustment**: In `squad_convert_example_to_features`, the context is potentially split into multiple spans if its length exceeds the model's maximum sequence length. This is crucial because it allows the model to handle long contexts by breaking them down into manageable parts. Each span is processed to ensure that the start and end positions of answers are correctly adjusted within the tokenized context. 
This adjustment is handled by the `_improve_answer_span` function, which ensures that the answer spans are accurately placed within the tokens, even if the context is segmented.\n\n5. **Feature Construction**: Each span is then converted into a set of features, including input IDs, attention masks, token type IDs, and the positions of the answers. Special care is taken to mark tokens that cannot be part of the answers (using a p_mask), and to identify the maximum context for each token, which is critical for understanding which part of the split context a token belongs to.\n\n6. **Dataset Compilation**: After processing, the features are compiled into a dataset format (either PyTorch or TensorFlow, based on the configuration). This dataset includes all necessary information for the model to learn from, including the context, the question, and the correct positions of the answers.\n\nBy carefully managing the tokenization, segmentation, and feature construction processes, `load_and_cache_examples` ensures that the dataset it produces allows models to accurately predict answer positions, regardless of the length of the context. This capability is essential for training robust question-answering models that can handle real-world data, where context lengths can vary significantly.",
138
+ "expected_identifiers": [
139
+ "squad_convert_examples_to_features",
140
+ "squad_convert_example_to_features",
141
+ "_improve_answer_span",
142
+ "p_mask"
143
+ ]
144
+ },
145
+ {
146
+ "repo": "huggingface/transformers",
147
+ "commit": "7bb1c99800d235791dace10305731f377db8077b",
148
+ "context_files": [
149
+ "huggingface/transformers/src/transformers/modeling_flax_utils.py",
150
+ "huggingface/transformers/src/transformers/utils/hub.py"
151
+ ],
152
+ "question": "In a scenario where network conditions are suboptimal, how does `FlaxPreTrainedModel.from_pretrained` manage to reduce the model loading time?",
153
+ "answer": "In scenarios where network conditions are suboptimal, the `FlaxPreTrainedModel.from_pretrained` function effectively reduces model loading time by leveraging a sophisticated caching mechanism. This mechanism is crucial for managing the download and storage of model shards, ensuring efficient and faster model initialization.\n\n### Caching Mechanism:\nThe function first checks if the required model shards are already available in the local cache before attempting any network requests. This is achieved through the `try_to_load_from_cache` function, which inspects the cache for the presence of the last shard of the model. If the last shard is found in the cache, it is likely that all previous shards are also cached, thus avoiding the need for further network requests.\n\n### Download and Cache Management:\nIf the shards are not found in the cache, `FlaxPreTrainedModel.from_pretrained` proceeds to download them. Each shard's presence is verified using the `cached_file` function, which handles the downloading and caching of the shard if it is not already present. This function also supports resuming downloads, which is particularly useful in suboptimal network conditions where downloads might be interrupted.\n\n### Efficient Shard Handling:\nThe function `get_checkpoint_shard_files` is specifically designed to manage sharded model files. It reads the checkpoint index file to determine all the necessary shards for the model and then ensures each shard is either fetched from the cache or downloaded. This process is streamlined by the use of a progress bar (managed by `tqdm`), which provides visual feedback on the download process, enhancing user experience especially in network-constrained environments.\n\n### Impact of Caching on Model Loading Time:\nBy prioritizing cached shards, `FlaxPreTrainedModel.from_pretrained` significantly reduces the dependency on network bandwidth and stability. 
This is particularly beneficial in scenarios with limited network resources, as it minimizes the time spent in downloading model components. The caching mechanism ensures that once a model shard is downloaded and stored locally, subsequent loads of the same model will utilize the cached versions, thereby bypassing the network entirely and leading to much faster model initialization times.\n\n### Conclusion:\nThe caching strategy employed by `FlaxPreTrainedModel.from_pretrained` not only optimizes the use of network resources but also ensures consistent and reduced model loading times, regardless of network conditions. This approach is instrumental in scenarios where models need to be switched frequently or reloaded, providing a seamless and efficient user experience.",
154
+ "expected_identifiers": [
155
+ "try_to_load_from_cache",
156
+ "cached_file",
157
+ "get_checkpoint_shard_files",
158
+ "tqdm"
159
+ ]
160
+ },
161
+ {
162
+ "repo": "huggingface/transformers",
163
+ "commit": "7bb1c99800d235791dace10305731f377db8077b",
164
+ "context_files": [
165
+ "huggingface/transformers/examples/research_projects/information-gain-filtration/run_clm_igf.py",
166
+ "huggingface/transformers/examples/research_projects/information-gain-filtration/igf/igf.py"
167
+ ],
168
+ "question": "In light of recent dataset size restrictions for training purposes, how does `generate_n_pairs` maintain compliance by ensuring the objective set adheres to the specified size and article length requirements?",
169
+ "answer": "The `generate_n_pairs` function ensures compliance with dataset size restrictions by meticulously managing the creation of the objective set through its subordinate function `generate_datasets`. This process is governed by specific parameters and conditions set within the code to meet the required criteria of size and article length.\n\n1. **Size of the Objective Set**: The function `generate_datasets` is designed to create an objective set that contains exactly the number of articles specified by the `number` parameter, which is passed from `generate_n_pairs` as `size_objective_set`. In the provided code, this value is set to 100. The loop within `generate_datasets` that populates the `objective_set` list includes a condition to break once the length of this list reaches the specified `number` (see the line `if len(objective_set) >= number: break`). This ensures that no more than 100 articles are added to the objective set, directly adhering to the dataset size restrictions.\n\n2. **Article Length Management**: The function also manages the length of each article in the objective set based on the `context_len` parameter. If `trim` is set to `True`, the function trims the articles to ensure they do not exceed the specified `context_len`. This is achieved by selecting a starting point randomly within the article and then slicing the article to obtain a segment of the specified `context_len` (see the line `objective_set.append(example[0, start : start + context_len])`). This ensures that each article in the objective set adheres to the length restrictions.\n\n3. **Compliance with Regulations**: By strictly controlling both the number of articles and their lengths as described, `generate_n_pairs` ensures that the objective set complies with new regulations requiring training datasets to contain no more than 100 articles, each of a specified maximum length. 
This compliance is crucial for ethical review and adherence to training dataset standards.\n\nIn summary, `generate_n_pairs` maintains compliance with dataset size and article length restrictions through careful implementation in `generate_datasets`, which explicitly controls the size of the objective set and trims articles to the required length based on the parameters provided. This methodical approach ensures that the objective set meets specified criteria, crucial for adhering to regulatory standards.",
170
+ "expected_identifiers": [
171
+ "generate_n_pairs",
172
+ "generate_datasets",
173
+ "size_objective_set",
174
+ "context_len"
175
+ ]
176
+ }
177
+ ]
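The greedy sharding procedure described in the `shard_checkpoint` answer above can be sketched in a few lines. This is a simplified illustration over byte sizes only, not the Transformers implementation; `shard_state_dict` is a hypothetical name.

```python
def shard_state_dict(sizes, max_shard_bytes):
    """Greedily split a {name: byte_size} mapping into shards, each at most
    max_shard_bytes. A single tensor larger than the limit gets its own shard,
    mirroring the oversized-tensor exception described above."""
    shards = [{}]
    current = 0
    for name, size in sizes.items():
        if size > max_shard_bytes:
            # Oversized tensor: isolate it in its own shard.
            shards.append({name: size})
            shards.append({})
            current = 0
            continue
        if current + size > max_shard_bytes:
            # Current shard is full; start a new one.
            shards.append({})
            current = 0
        shards[-1][name] = size
        current += size
    # Drop any empty shard, then build the parameter -> shard index map
    # (the analogue of the weight_map index file).
    shards = [s for s in shards if s]
    weight_map = {name: i for i, shard in enumerate(shards) for name in shard}
    return shards, weight_map
```

Every shard either respects the size cap or contains exactly one oversized tensor, and `weight_map` records where each parameter landed.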
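The eps-guarded normalization from the `compute_tm` answer above also reduces to a small, checkable pattern. This is a plain-Python sketch; `weighted_tm` is a hypothetical name, not the ESMFold API.

```python
def weighted_tm(predicted_tm_term, residue_weights, eps=1e-8):
    """Weighted average of per-residue TM terms. The eps term in the
    denominator keeps the result defined even when every residue weight
    is zero (the division-by-zero case discussed above)."""
    total = sum(residue_weights)
    normed = [w / (eps + total) for w in residue_weights]
    return sum(t * w for t, w in zip(predicted_tm_term, normed))
```

With all-zero weights the function returns 0.0 instead of raising or producing NaN, which is the edge case the epsilon exists to cover.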
requirements.txt CHANGED
@@ -1,18 +1,20 @@
1
  GitPython==3.1.43
2
  Pygments==2.18.0
3
  cohere==5.9.2
 
4
  fastapi==0.112.2
5
  gradio>=4.26.0
6
  langchain==0.2.16
7
  langchain-anthropic==0.1.23
8
  langchain-cohere==0.2.4
9
  langchain-community==0.2.17
10
- langchain-core==0.2.40
11
  langchain-experimental==0.0.65
12
  langchain-nvidia-ai-endpoints==0.2.2
13
  langchain-ollama==0.1.3
14
  langchain-openai==0.1.25
15
  langchain-text-splitters==0.2.4
 
16
  marqo==3.7.0
17
  nbformat==5.10.4
18
  openai==1.42.0
@@ -22,8 +24,10 @@ python-dotenv==1.0.1
22
  requests==2.32.3
23
  semchunk==2.2.0
24
  sentence-transformers==3.1.0
 
25
  tiktoken==0.7.0
26
  tokenizers==0.19.1
27
  transformers==4.44.2
28
  tree-sitter==0.22.3
29
  tree-sitter-language-pack==0.2.0
 
 
1
  GitPython==3.1.43
2
  Pygments==2.18.0
3
  cohere==5.9.2
4
+ configargparse
5
  fastapi==0.112.2
6
  gradio>=4.26.0
7
  langchain==0.2.16
8
  langchain-anthropic==0.1.23
9
  langchain-cohere==0.2.4
10
  langchain-community==0.2.17
11
+ langchain-core==0.2.41
12
  langchain-experimental==0.0.65
13
  langchain-nvidia-ai-endpoints==0.2.2
14
  langchain-ollama==0.1.3
15
  langchain-openai==0.1.25
16
  langchain-text-splitters==0.2.4
17
+ langchain-voyageai==0.1.1
18
  marqo==3.7.0
19
  nbformat==5.10.4
20
  openai==1.42.0
 
24
  requests==2.32.3
25
  semchunk==2.2.0
26
  sentence-transformers==3.1.0
27
+ tenacity==8.5.0
28
  tiktoken==0.7.0
29
  tokenizers==0.19.1
30
  transformers==4.44.2
31
  tree-sitter==0.22.3
32
  tree-sitter-language-pack==0.2.0
33
+ voyageai==0.2.3
sage/chat.py CHANGED
@@ -7,18 +7,15 @@ import logging

  import configargparse
  import gradio as gr
- import pkg_resources
  from dotenv import load_dotenv
  from langchain.chains import create_history_aware_retriever, create_retrieval_chain
  from langchain.chains.combine_documents import create_stuff_documents_chain
- from langchain.retrievers import ContextualCompressionRetriever
  from langchain.schema import AIMessage, HumanMessage
  from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

  import sage.config as sage_config
  from sage.llm import build_llm_via_langchain
- from sage.reranker import build_reranker
- from sage.vector_store import build_vector_store_from_args
+ from sage.retriever import build_retriever_from_args

  load_dotenv()

@@ -26,12 +23,7 @@ load_dotenv()
  def build_rag_chain(args):
      """Builds a RAG chain via LangChain."""
      llm = build_llm_via_langchain(args.llm_provider, args.llm_model)
-
-     retriever_top_k = 5 if args.reranker_provider == "none" else 25
-     retriever = build_vector_store_from_args(args).as_retriever(top_k=retriever_top_k)
-     compressor = build_reranker(args.reranker_provider, args.reranker_model)
-     if compressor:
-         retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)
+     retriever = build_retriever_from_args(args)

      # Prompt to contextualize the latest query based on the chat history.
      contextualize_q_system_prompt = (
@@ -83,6 +75,7 @@ def main():

      arg_validators = [
          sage_config.add_repo_args(parser),
+         sage_config.add_embedding_args(parser),
          sage_config.add_vector_store_args(parser),
          sage_config.add_reranking_args(parser),
          sage_config.add_llm_args(parser),
sage/chunker.py CHANGED
@@ -299,6 +299,7 @@ class UniversalFileChunker(Chunker):
      """Chunks a file into smaller pieces, regardless of whether it's code or text."""

      def __init__(self, max_tokens: int):
+         self.max_tokens = max_tokens
          self.code_chunker = CodeFileChunker(max_tokens)
          self.ipynb_chunker = IpynbFileChunker(self.code_chunker)
          self.text_chunker = TextFileChunker(max_tokens)
sage/config.py CHANGED
@@ -28,6 +28,25 @@ OPENAI_DEFAULT_EMBEDDING_SIZE = {
      "text-embedding-3-large": 3072,
  }

+ VOYAGE_MAX_CHUNKS_PER_BATCH = 128
+
+
+ def get_voyage_max_tokens_per_batch(model: str) -> int:
+     """Returns the maximum number of tokens per batch for the Voyage model.
+     See https://docs.voyageai.com/reference/embeddings-api."""
+     if model == "voyage-3-lite":
+         return 1_000_000
+     if model in ["voyage-3", "voyage-2"]:
+         return 320_000
+     return 120_000
+
+
+ def get_voyage_embedding_size(model: str) -> int:
+     """Returns the embedding size for the Voyage model. See https://docs.voyageai.com/docs/embeddings#model-choices."""
+     if model == "voyage-3-lite":
+         return 512
+     if model == "voyage-code-2":
+         return 1536
+     return 1024
+

  def add_config_args(parser: ArgumentParser):
      """Adds configuration-related arguments to the parser."""
@@ -61,7 +80,7 @@ def add_repo_args(parser: ArgumentParser) -> Callable:

  def add_embedding_args(parser: ArgumentParser) -> Callable:
      """Adds embedding-related arguments to the parser and returns a validator."""
-     parser.add("--embedding-provider", default="marqo", choices=["openai", "marqo"])
+     parser.add("--embedding-provider", default="marqo", choices=["openai", "voyage", "marqo"])
      parser.add(
          "--embedding-model",
          type=str,
@@ -115,11 +134,17 @@ def add_vector_store_args(parser: ArgumentParser) -> Callable:
          help="URL for the Marqo server. Required if using Marqo as embedder or vector store.",
      )
      parser.add(
-         "--hybrid-retrieval",
-         action=argparse.BooleanOptionalAction,
-         default=True,
-         help="Whether to use a hybrid of vector DB + BM25 retrieval. When set to False, we only use vector DB "
-         "retrieval. This is only relevant if using Pinecone as the vector store.",
+         "--retrieval-alpha",
+         default=0.5,
+         type=float,
+         help="Takes effect for the Pinecone retriever only. The weight of the dense (embeddings-based) vs sparse (BM25) "
+         "encoder in the final retrieval score. A value of 0.0 means BM25 only, 1.0 means embeddings only.",
+     )
+     parser.add(
+         "--retriever-top-k",
+         default=25,
+         type=int,
+         help="The number of top documents to retrieve from the vector store.",
      )
      return validate_vector_store_args

@@ -169,6 +194,7 @@ def add_reranking_args(parser: ArgumentParser) -> Callable:
          help="The reranker model name. When --reranker-provider=huggingface, we suggest choosing a model from the "
          "SentenceTransformers Cross-Encoders library https://huggingface.co/cross-encoder?sort_models=downloads#models",
      )
+     parser.add("--reranker-top-k", default=5, type=int, help="The number of top documents to return after reranking.")
      # Trivial validator (nothing to check).
      return lambda _: True

@@ -228,6 +254,33 @@ def _validate_openai_embedding_args(args):
          raise ValueError(f"The maximum number of chunks per job is {OPENAI_MAX_TOKENS_PER_JOB}. Got {chunks_per_job}")


+ def _validate_voyage_embedding_args(args):
+     """Validates the configuration of the Voyage batch embedder and sets defaults."""
+     if not os.getenv("VOYAGE_API_KEY"):
+         raise ValueError("Please set the VOYAGE_API_KEY environment variable.")
+
+     if not args.embedding_model:
+         args.embedding_model = "voyage-code-2"
+
+     if not args.tokens_per_chunk:
+         # https://arxiv.org/pdf/2406.14497 recommends a value between 200-800.
+         args.tokens_per_chunk = 800
+
+     if not args.chunks_per_batch:
+         args.chunks_per_batch = VOYAGE_MAX_CHUNKS_PER_BATCH
+     elif args.chunks_per_batch > VOYAGE_MAX_CHUNKS_PER_BATCH:
+         args.chunks_per_batch = VOYAGE_MAX_CHUNKS_PER_BATCH
+         logging.warning(f"Voyage enforces a limit of {VOYAGE_MAX_CHUNKS_PER_BATCH} chunks per batch. Overwriting.")
+
+     max_tokens = get_voyage_max_tokens_per_batch(args.embedding_model)
+     if args.tokens_per_chunk * args.chunks_per_batch > max_tokens:
+         raise ValueError(f"Voyage enforces a limit of {max_tokens} tokens per batch. "
+                          "Reduce either --tokens-per-chunk or --chunks-per-batch.")
+
+     if not args.embedding_size:
+         args.embedding_size = get_voyage_embedding_size(args.embedding_model)
+
+
  def _validate_marqo_embedding_args(args):
      """Validates the configuration of the Marqo batch embedder and sets defaults."""
      if not args.embedding_model:
@@ -247,6 +300,8 @@ def validate_embedding_args(args):
      """Validates the configuration of the batch embedder and sets defaults."""
      if args.embedding_provider == "openai":
          _validate_openai_embedding_args(args)
+     elif args.embedding_provider == "voyage":
+         _validate_voyage_embedding_args(args)
      elif args.embedding_provider == "marqo":
          _validate_marqo_embedding_args(args)
      else:
@@ -257,8 +312,11 @@ def validate_vector_store_args(args):
      """Validates the configuration of the vector store and sets defaults."""

      if not args.index_namespace:
+         # Attempt to derive a default index namespace from the repository information.
+         if "repo_id" not in args:
+             raise ValueError("Please set a value for --index-namespace.")
          args.index_namespace = args.repo_id
-     if args.commit_hash:
+     if "commit_hash" in args and args.commit_hash:
          args.index_namespace += "/" + args.commit_hash
      if args.vector_store_provider == "marqo":
          # Marqo doesn't allow slashes in the index namespace.
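The Voyage limits above combine into a single batch-budget constraint: a batch is valid only if it respects both the 128-chunk cap and the per-model token cap. A standalone sketch of that check, using the limits quoted in the diff (the function name is illustrative, not part of the codebase):

```python
VOYAGE_MAX_CHUNKS_PER_BATCH = 128

def voyage_batch_is_valid(model: str, tokens_per_chunk: int, chunks_per_batch: int) -> bool:
    """Returns True if a batch of chunks fits Voyage's per-batch limits."""
    if model == "voyage-3-lite":
        max_tokens = 1_000_000
    elif model in ("voyage-3", "voyage-2"):
        max_tokens = 320_000
    else:  # e.g. voyage-code-2
        max_tokens = 120_000
    return (chunks_per_batch <= VOYAGE_MAX_CHUNKS_PER_BATCH
            and tokens_per_chunk * chunks_per_batch <= max_tokens)
```

With the validator's defaults (800 tokens/chunk, 128 chunks/batch), a full batch is 800 * 128 = 102,400 tokens, which fits voyage-code-2's 120k budget.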
sage/configs/remote.yaml CHANGED
@@ -14,5 +14,4 @@ llm-provider: openai
  llm-model: gpt-4

  # Reranking
- reranking-provider: cohere
- reranking-model: rerank-english-v3.0
+ reranker-provider: nvidia
sage/embedder.py CHANGED
@@ -9,7 +9,9 @@ from collections import Counter
  from typing import Dict, Generator, List, Optional, Tuple

  import marqo
+ import requests
  from openai import OpenAI
+ from tenacity import retry, stop_after_attempt, wait_random_exponential

  from sage.chunker import Chunk, Chunker
  from sage.constants import TEXT_FIELD
@@ -205,6 +207,72 @@ class OpenAIBatchEmbedder(BatchEmbedder):
      }


+ class VoyageBatchEmbedder(BatchEmbedder):
+     """Batch embedder that calls Voyage. See https://docs.voyageai.com/reference/embeddings-api."""
+
+     def __init__(self, data_manager: DataManager, chunker: Chunker, embedding_model: str):
+         self.data_manager = data_manager
+         self.chunker = chunker
+         self.embedding_model = embedding_model
+         self.embedding_data = []
+
+     def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None):
+         """Issues batch embedding jobs for the entire dataset."""
+         batch = []
+         chunk_count = 0
+         tokens_since_pause = 0
+
+         for content, metadata in self.data_manager.walk():
+             chunks = self.chunker.chunk(content, metadata)
+             chunk_count += len(chunks)
+             batch.extend(chunks)
+
+             # Voyage rate-limits to 1M tokens per minute; pause after every ~900k tokens.
+             tokens_since_pause += len(chunks) * self.chunker.max_tokens
+             if tokens_since_pause >= 900_000:
+                 logging.info("Pausing for 60 seconds to avoid rate limiting...")
+                 time.sleep(60)
+                 tokens_since_pause = 0
+
+             if len(batch) > chunks_per_batch:
+                 for i in range(0, len(batch), chunks_per_batch):
+                     sub_batch = batch[i : i + chunks_per_batch]
+                     logging.info("Embedding %d chunks...", len(sub_batch))
+                     result = self._make_batch_request(sub_batch)
+                     for chunk, datum in zip(sub_batch, result["data"]):
+                         self.embedding_data.append((chunk.metadata, datum["embedding"]))
+                 batch = []
+
+         # Finally, commit the last batch.
+         if batch:
+             logging.info("Embedding %d chunks...", len(batch))
+             result = self._make_batch_request(batch)
+             for chunk, datum in zip(batch, result["data"]):
+                 self.embedding_data.append((chunk.metadata, datum["embedding"]))
+
+         logging.info(f"Successfully embedded {chunk_count} chunks.")
+
+     def embeddings_are_ready(self, *args, **kwargs) -> bool:
+         """Checks whether the batch embedding jobs are done."""
+         # The Voyage API is synchronous, so once embed_dataset() returns, the embeddings are ready.
+         return True
+
+     def download_embeddings(self, *args, **kwargs) -> Generator[Vector, None, None]:
+         """Yields (chunk_metadata, embedding) pairs for each chunk in the dataset."""
+         for chunk_metadata, embedding in self.embedding_data:
+             yield (chunk_metadata, embedding)
+
+     @retry(wait=wait_random_exponential(multiplier=1, max=60), stop=stop_after_attempt(6))
+     def _make_batch_request(self, chunks: List[Chunk]) -> Dict:
+         """Makes a batch request to the Voyage API with exponential backoff when we hit rate limits."""
+         url = "https://api.voyageai.com/v1/embeddings"
+         headers = {"Authorization": f"Bearer {os.environ['VOYAGE_API_KEY']}", "Content-Type": "application/json"}
+         payload = {"input": [chunk.content for chunk in chunks], "model": self.embedding_model}
+
+         response = requests.post(url, json=payload, headers=headers)
+         if response.status_code != 200:
+             raise ValueError(f"Failed to make batch request. Response: {response.text}")
+
+         return response.json()
+
+
  class MarqoEmbedder(BatchEmbedder):
      """Embedder that uses the open-source Marqo vector search engine.

@@ -270,6 +338,8 @@ class MarqoEmbedder(BatchEmbedder):
  def build_batch_embedder_from_flags(data_manager: DataManager, chunker: Chunker, args) -> BatchEmbedder:
      if args.embedding_provider == "openai":
          return OpenAIBatchEmbedder(data_manager, chunker, args.local_dir, args.embedding_model, args.embedding_size)
+     elif args.embedding_provider == "voyage":
+         return VoyageBatchEmbedder(data_manager, chunker, args.embedding_model)
      elif args.embedding_provider == "marqo":
          return MarqoEmbedder(
              data_manager, chunker, index_name=args.index_namespace, url=args.marqo_url, model=args.embedding_model
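Because the Voyage API is synchronous, `embed_dataset` simply accumulates chunks and flushes them in fixed-size sub-batches. The splitting step can be sketched in isolation (names are illustrative, not from the library):

```python
from typing import List

def split_into_batches(items: List[str], batch_size: int) -> List[List[str]]:
    """Splits items into consecutive sub-batches of at most batch_size elements each."""
    return [items[i : i + batch_size] for i in range(0, len(items), batch_size)]
```

The last sub-batch may be smaller than `batch_size`, which mirrors the "finally, commit the last batch" step in the embedder.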
sage/index.py CHANGED
@@ -90,7 +90,7 @@ def main():
      time.sleep(30)

      logging.info("Moving embeddings to the repo vector store...")
-     repo_vector_store = build_vector_store_from_args(args)
+     repo_vector_store = build_vector_store_from_args(args, repo_manager)
      repo_vector_store.ensure_exists()
      repo_vector_store.upsert(repo_embedder.download_embeddings(repo_jobs_file))

@@ -101,7 +101,7 @@ def main():
      time.sleep(30)

      logging.info("Moving embeddings to the issues vector store...")
-     issues_vector_store = build_vector_store_from_args(args)
+     issues_vector_store = build_vector_store_from_args(args, issues_manager)
      issues_vector_store.ensure_exists()
      issues_vector_store.upsert(issues_embedder.download_embeddings(issues_jobs_file))
sage/reranker.py CHANGED
@@ -8,6 +8,7 @@ from langchain_community.cross_encoders import HuggingFaceCrossEncoder
  from langchain_community.document_compressors import JinaRerank
  from langchain_core.documents import BaseDocumentCompressor
  from langchain_nvidia_ai_endpoints import NVIDIARerank
+ from langchain_voyageai import VoyageAIRerank


  class RerankerProvider(Enum):
@@ -16,27 +17,33 @@ class RerankerProvider(Enum):
      COHERE = "cohere"
      NVIDIA = "nvidia"
      JINA = "jina"
+     VOYAGE = "voyage"


- def build_reranker(provider: str, model: Optional[str] = None, top_n: Optional[int] = 5) -> BaseDocumentCompressor:
+ def build_reranker(provider: str, model: Optional[str] = None, top_k: Optional[int] = 5) -> BaseDocumentCompressor:
      if provider == RerankerProvider.NONE.value:
          return None
      if provider == RerankerProvider.HUGGINGFACE.value:
          model = model or "cross-encoder/ms-marco-MiniLM-L-6-v2"
          encoder_model = HuggingFaceCrossEncoder(model_name=model)
-         return CrossEncoderReranker(model=encoder_model, top_n=top_n)
+         return CrossEncoderReranker(model=encoder_model, top_n=top_k)
      if provider == RerankerProvider.COHERE.value:
          if not os.environ.get("COHERE_API_KEY"):
              raise ValueError("Please set the COHERE_API_KEY environment variable")
          model = model or "rerank-english-v3.0"
-         return CohereRerank(model=model, cohere_api_key=os.environ.get("COHERE_API_KEY"), top_n=top_n)
+         return CohereRerank(model=model, cohere_api_key=os.environ.get("COHERE_API_KEY"), top_n=top_k)
      if provider == RerankerProvider.NVIDIA.value:
          if not os.environ.get("NVIDIA_API_KEY"):
              raise ValueError("Please set the NVIDIA_API_KEY environment variable")
          model = model or "nvidia/nv-rerankqa-mistral-4b-v3"
-         return NVIDIARerank(model=model, api_key=os.environ.get("NVIDIA_API_KEY"), top_n=top_n, truncate="END")
+         return NVIDIARerank(model=model, api_key=os.environ.get("NVIDIA_API_KEY"), top_n=top_k, truncate="END")
      if provider == RerankerProvider.JINA.value:
          if not os.environ.get("JINA_API_KEY"):
              raise ValueError("Please set the JINA_API_KEY environment variable")
-         return JinaRerank(top_n=top_n)
+         return JinaRerank(top_n=top_k)
+     if provider == RerankerProvider.VOYAGE.value:
+         if not os.environ.get("VOYAGE_API_KEY"):
+             raise ValueError("Please set the VOYAGE_API_KEY environment variable")
+         model = model or "rerank-1"
+         return VoyageAIRerank(model=model, api_key=os.environ.get("VOYAGE_API_KEY"), top_k=top_k)
      raise ValueError(f"Invalid reranker provider: {provider}")
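Each provider branch above follows the same pattern: check the API key, fall back to a provider-specific default model, then construct the compressor. The fallback step in isolation (table values copied from the diff; Jina uses its client-side default, so it has no entry here, and the function name is illustrative):

```python
from typing import Optional

# Default models per provider, mirroring build_reranker above.
DEFAULT_RERANKER_MODELS = {
    "huggingface": "cross-encoder/ms-marco-MiniLM-L-6-v2",
    "cohere": "rerank-english-v3.0",
    "nvidia": "nvidia/nv-rerankqa-mistral-4b-v3",
    "voyage": "rerank-1",
}

def resolve_reranker_model(provider: str, model: Optional[str] = None) -> Optional[str]:
    """Returns the model to use: the explicit one if given, else the provider default."""
    if provider == "none":
        return None
    if provider != "jina" and provider not in DEFAULT_RERANKER_MODELS:
        raise ValueError(f"Invalid reranker provider: {provider}")
    return model or DEFAULT_RERANKER_MODELS.get(provider)
```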
sage/retriever.py ADDED
@@ -0,0 +1,25 @@
+ from langchain.retrievers import ContextualCompressionRetriever
+ from langchain_openai import OpenAIEmbeddings
+ from langchain_voyageai import VoyageAIEmbeddings
+
+ from sage.reranker import build_reranker
+ from sage.vector_store import build_vector_store_from_args
+
+
+ def build_retriever_from_args(args):
+     """Builds a retriever (with optional reranking) from command-line arguments."""
+     if args.embedding_provider == "openai":
+         embeddings = OpenAIEmbeddings(model=args.embedding_model)
+     elif args.embedding_provider == "voyage":
+         embeddings = VoyageAIEmbeddings(model=args.embedding_model)
+     else:
+         embeddings = None
+
+     retriever = build_vector_store_from_args(args).as_retriever(top_k=args.retriever_top_k, embeddings=embeddings)
+
+     reranker = build_reranker(args.reranker_provider, args.reranker_model, args.reranker_top_k)
+     if reranker:
+         retriever = ContextualCompressionRetriever(base_compressor=reranker, base_retriever=retriever)
+     return retriever
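The retriever built here over-fetches (default `--retriever-top-k=25`) and relies on the reranker to cut the list down (default `--reranker-top-k=5`). The shape of that two-stage flow, with stand-in scoring functions (everything below is illustrative, not the library's API):

```python
from typing import Callable, List

def retrieve_then_rerank(
    docs: List[str],
    retrieval_score: Callable[[str], float],
    rerank_score: Callable[[str], float],
    retriever_top_k: int = 25,
    reranker_top_k: int = 5,
) -> List[str]:
    """Stage 1: rank by a cheap retrieval score and keep retriever_top_k candidates.
    Stage 2: re-rank the survivors with a more expensive score and keep reranker_top_k."""
    candidates = sorted(docs, key=retrieval_score, reverse=True)[:retriever_top_k]
    return sorted(candidates, key=rerank_score, reverse=True)[:reranker_top_k]
```

Over-fetching matters because the reranker can only promote documents the first stage surfaced; with `--reranker-provider=none`, the first stage's top results are returned directly.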
sage/vector_store.py CHANGED
@@ -1,19 +1,22 @@
  """Vector store abstraction and implementations."""

+ import logging
+ import os
  from abc import ABC, abstractmethod
  from functools import cached_property
- from typing import Dict, Generator, List, Tuple
+ from typing import Dict, Generator, List, Optional, Tuple

  import marqo
  from langchain_community.retrievers import PineconeHybridSearchRetriever
  from langchain_community.vectorstores import Marqo
  from langchain_community.vectorstores import Pinecone as LangChainPinecone
  from langchain_core.documents import Document
- from langchain_openai import OpenAIEmbeddings
+ from langchain_core.embeddings import Embeddings
  from pinecone import Pinecone, ServerlessSpec
  from pinecone_text.sparse import BM25Encoder

  from sage.constants import TEXT_FIELD
+ from sage.data_manager import DataManager

  Vector = Tuple[Dict, List[float]]  # (metadata, embedding)

@@ -41,24 +44,40 @@ class VectorStore(ABC):
          self.upsert_batch(batch)

      @abstractmethod
-     def as_retriever(self, top_k: int):
+     def as_retriever(self, top_k: int, embeddings: Embeddings):
          """Converts the vector store to a LangChain retriever object."""


  class PineconeVectorStore(VectorStore):
      """Vector store implementation using Pinecone."""

-     def __init__(self, index_name: str, namespace: str, dimension: int, hybrid: bool = True):
+     def __init__(self, index_name: str, namespace: str, dimension: int, alpha: float, bm25_cache: Optional[str] = None):
+         """
+         Args:
+             index_name: The name of the Pinecone index to use. If it doesn't exist already, we'll create it.
+             namespace: The namespace within the index to use.
+             dimension: The dimension of the vectors.
+             alpha: The alpha parameter for hybrid search: alpha == 1.0 means pure dense search, alpha == 0.0 means
+                 pure BM25, and 0.0 < alpha < 1.0 means a hybrid of the two.
+             bm25_cache: The path to the BM25 encoder file. If not specified, we'll use the default BM25 encoder
+                 (fitted on the MS MARCO dataset).
+         """
          self.index_name = index_name
          self.dimension = dimension
          self.client = Pinecone()
          self.namespace = namespace
-         self.hybrid = hybrid
-         # The default BM25 encoder was fit in the MS MARCO dataset.
-         # See https://docs.pinecone.io/guides/data/encode-sparse-vectors
-         # In the future, we should fit the encoder on the current dataset. It's somewhat non-trivial for large datasets,
-         # because most BM25 implementations require the entire dataset to fit in memory.
-         self.bm25_encoder = BM25Encoder.default() if hybrid else None
+         self.alpha = alpha
+
+         if alpha < 1.0:
+             if bm25_cache and os.path.exists(bm25_cache):
+                 logging.info("Loading BM25 encoder from cache.")
+                 self.bm25_encoder = BM25Encoder()
+                 self.bm25_encoder.load(path=bm25_cache)
+             else:
+                 logging.info("Using default BM25 encoder (fitted on MS MARCO).")
+                 self.bm25_encoder = BM25Encoder.default()
+         else:
+             self.bm25_encoder = None

      @cached_property
      def index(self):
@@ -84,7 +103,7 @@ class PineconeVectorStore(VectorStore):
              name=self.index_name,
              dimension=self.dimension,
              # See https://www.pinecone.io/learn/hybrid-search-intro/
-             metric="dotproduct" if self.hybrid else "cosine",
+             metric="dotproduct" if self.bm25_encoder else "cosine",
              spec=ServerlessSpec(cloud="aws", region="us-east-1"),
          )

@@ -98,19 +117,19 @@ class PineconeVectorStore(VectorStore):

          self.index.upsert(vectors=pinecone_vectors, namespace=self.namespace)

-     def as_retriever(self, top_k: int):
+     def as_retriever(self, top_k: int, embeddings: Embeddings):
          if self.bm25_encoder:
              return PineconeHybridSearchRetriever(
-                 embeddings=OpenAIEmbeddings(),
+                 embeddings=embeddings,
                  sparse_encoder=self.bm25_encoder,
                  index=self.index,
                  namespace=self.namespace,
                  top_k=top_k,
-                 alpha=0.5,
+                 alpha=self.alpha,
              )

          return LangChainPinecone.from_existing_index(
-             index_name=self.index_name, embedding=OpenAIEmbeddings(), namespace=self.namespace
+             index_name=self.index_name, embedding=embeddings, namespace=self.namespace
          ).as_retriever(search_kwargs={"k": top_k})


@@ -128,7 +147,8 @@ class MarqoVectorStore(VectorStore):
          # Since Marqo is both an embedder and a vector store, the embedder is already doing the upsert.
          pass

-     def as_retriever(self, top_k: int):
+     def as_retriever(self, top_k: int, embeddings: Embeddings = None):
+         del embeddings  # Unused; the Marqo vector store is also an embedder.
          vectorstore = Marqo(client=self.client, index_name=self.index_name)

          # Monkey-patch the _construct_documents_from_results_without_score method to not expect a "metadata" field in
@@ -146,14 +166,32 @@ class MarqoVectorStore(VectorStore):
          return vectorstore.as_retriever(search_kwargs={"k": top_k})


- def build_vector_store_from_args(args: dict) -> VectorStore:
-     """Builds a vector store from the given command-line arguments."""
+ def build_vector_store_from_args(args: dict, data_manager: Optional[DataManager] = None) -> VectorStore:
+     """Builds a vector store from the given command-line arguments.
+
+     When `data_manager` is specified and hybrid retrieval is requested, we'll use it to fit a BM25 encoder on the
+     corpus of documents.
+     """
      if args.vector_store_provider == "pinecone":
+         bm25_cache = os.path.join(".bm25_cache", args.index_namespace, "bm25_encoder.json")
+
+         if not os.path.exists(bm25_cache) and data_manager:
+             logging.info("Fitting BM25 encoder on the corpus...")
+             corpus = [content for content, _ in data_manager.walk()]
+             bm25_encoder = BM25Encoder()
+             bm25_encoder.fit(corpus)
+             # Make sure the folder exists before we dump the encoder.
+             os.makedirs(os.path.dirname(bm25_cache), exist_ok=True)
+             bm25_encoder.dump(bm25_cache)
+
          return PineconeVectorStore(
              index_name=args.pinecone_index_name,
              namespace=args.index_namespace,
              dimension=args.embedding_size if "embedding_size" in args else None,
-             hybrid=args.hybrid_retrieval,
+             alpha=args.retrieval_alpha,
+             bm25_cache=bm25_cache,
          )
      elif args.vector_store_provider == "marqo":
          return MarqoVectorStore(url=args.marqo_url, index_name=args.index_namespace)
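The new `--retrieval-alpha` flag follows Pinecone's convex-combination convention for hybrid search: the final score is a weighted blend of the dense (embedding) and sparse (BM25) scores. A minimal sketch of that rule (stand-in scores, not Pinecone's actual implementation, which scales the vectors themselves):

```python
def hybrid_score(dense: float, sparse: float, alpha: float) -> float:
    """Convex combination of dense (embedding) and sparse (BM25) relevance scores.
    alpha == 1.0 means dense only; alpha == 0.0 means BM25 only."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be in [0, 1]")
    return alpha * dense + (1 - alpha) * sparse
```

This also explains why the index metric is `dotproduct` only when a BM25 encoder is present: Pinecone's hybrid queries require dot-product indexes.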