Retrieval benchmark (#39)
* Fixes for previous PR
* Add retriever.py
* Add retrieval benchmark
* Add --retrieval-alpha flag
* Fit BM25 to the current corpus and add Voyage embeddings
* Add benchmark README.
* Nits to retrieve.py
* Add Voyage reranker.
* Update README to reflect Voyage embeddings and reranker.
* Address reviewer comments
- README.md +12 -6
- benchmarks/retrieval/README.md +132 -0
- benchmarks/retrieval/assets/chunks.png +0 -0
- benchmarks/retrieval/assets/embeddings.png +0 -0
- benchmarks/retrieval/assets/markdown.png +0 -0
- benchmarks/retrieval/assets/rerankers.png +0 -0
- benchmarks/retrieval/assets/retrievers.png +0 -0
- benchmarks/retrieval/retrieve.py +108 -0
- benchmarks/retrieval/sample.json +177 -0
- requirements.txt +5 -1
- sage/chat.py +3 -10
- sage/chunker.py +1 -0
- sage/config.py +65 -7
- sage/configs/remote.yaml +1 -2
- sage/embedder.py +70 -0
- sage/index.py +2 -2
- sage/reranker.py +12 -5
- sage/retriever.py +25 -0
- sage/vector_store.py +57 -19
README.md
CHANGED
@@ -72,22 +72,28 @@ pip install git+https://github.com/Storia-AI/sage.git@main
 <details>
 <summary><strong>:cloud: Using external providers (higher quality)</strong></summary>
 
-1.
+1. For embeddings, we support <a href="https://platform.openai.com/docs/guides/embeddings">OpenAI</a> and <a href="https://docs.voyageai.com/docs/embeddings">Voyage</a>. According to [our experiments](benchmarks/retrieval/README.md), OpenAI is better quality. Their batch API is also faster, with more generous rate limits. Export the API key of the desired provider:
 
 ```
-export OPENAI_API_KEY=...
-export
+export OPENAI_API_KEY=... # or
+export VOYAGE_API_KEY=...
 ```
 
-2.
+2. We use <a href="https://www.pinecone.io/">Pinecone</a> for the vector store, so you will need an API key:
+
+```
+export PINECONE_API_KEY=...
+```
+
+If you want to reuse an existing Pinecone index, specify it. Otherwise we'll create a new one called `sage`.
 ```
 export PINECONE_INDEX_NAME=...
 ```
 
-3. For reranking, we
+3. For reranking, we support <a href="https://developer.nvidia.com/blog/enhancing-rag-pipelines-with-re-ranking/">NVIDIA</a>, <a href="https://docs.voyageai.com/docs/reranker">Voyage</a>, <a href="https://cohere.com/rerank">Cohere</a>, and <a href="https://jina.ai/reranker/">Jina</a>. According to [our experiments](benchmarks/retrieval/README.md), NVIDIA performs best. Export the API key of the desired provider:
 ```
-export COHERE_API_KEY=... # or
 export NVIDIA_API_KEY=... # or
+export VOYAGE_API_KEY=... # or
+export COHERE_API_KEY=... # or
 export JINA_API_KEY=...
 ```
 
benchmarks/retrieval/README.md
ADDED
|
@@ -0,0 +1,132 @@
# Chat-with-your-codebase: Retrieval Benchmark

When using this repository (which allows you to chat with your codebase in two commands), you are indirectly making a series of choices that greatly influence the quality of your AI copilot: chunking strategy, embeddings, retrieval algorithm, rerankers, etc.

Our role as maintainers is two-fold: to give you options and flexibility, but also to find good defaults. We're not here just to dump code on the Internet. We're here to *make it work*.

To make progress, we need a ladder to climb. That's why we partnered with our friends at [Morph Labs](https://morph.so) to produce a benchmark that allows us to make informed decisions and measure progress. We will make it public soon, but if you really can't wait, let us know at [founders@storia.ai](mailto:founders@storia.ai).

Here you will find our first learnings enabled by this dataset. We focused on proprietary APIs, but we plan to extend the experiments to open-source models as well.

#### TL;DR
- OpenAI's `text-embedding-3-small` embeddings perform best.
- NVIDIA's reranker outperforms Cohere, Voyage and Jina.
- Sparse retrieval (e.g. BM25) actively hurts code retrieval if your index contains natural-language files (e.g. Markdown).
- Chunks of size 800 are ideal; going smaller brings very marginal gains.
- Going beyond `top_k=25` for retrieval has diminishing returns.

And now, if you want to nerd out, here's a bunch of plots and stats.
## Dataset
Our dataset consists of 1,000 `<question, answer, relevant_documents>` triples that focus on Hugging Face's [Transformers](https://github.com/huggingface/transformers) library.

The dataset was generated artificially and checked for quality by humans (we collaborated with [Morph Labs](https://morph.so)). The questions were designed to require context from 1-3 different Python files in order to be answered correctly.

A sample of 10 instances is provided in [sample.json](sample.json).

### Code Retrieval Benchmark
Here, we will be using the `<question, relevant_documents>` pairs as a code retrieval benchmark. For instance:
```
- Question:
When developing a specialized model class in the Transformers library, how does `auto_class_update` ensure that the new class's methods are tailored specifically for its requirements while preserving the functionality of the original methods from the base class?

- Relevant documents:
huggingface/transformers/src/transformers/models/auto/auto_factory.py
huggingface/transformers/src/transformers/utils/doc.py
```

#### Why not use an already-established code retrieval benchmark?
Indeed, there are already comprehensive code retrieval benchmarks like [CoIR](https://arxiv.org/abs/2407.02883). In fact, the [CosQA](https://arxiv.org/abs/2105.13239) subset of this benchmark has a similar format to ours (text-to-code retrieval for web queries).

However, we designed our document space to be *an entire codebase*, as opposed to a set of isolated Python functions. A real-world codebase contains a variety of files, including ones that are distracting and get undeservedly selected by the retriever. For instance, dense retrievers tend to prefer short files, and READMEs tend to score high even when irrelevant, since they're written in natural language. Our benchmark is able to surface such behaviors. It also allows us to experiment with a variety of strategies such as file chunking.

In the rest of this document, we'll share a few initial learnings enabled by our benchmark.
### Metrics

Throughout this report, we will use the following evaluation metrics, as implemented by the [ir-measures](https://ir-measur.es/en/latest/) library.
- [R-Precision](https://ir-measur.es/en/latest/measures.html#rprec): The precision at R, where R is the number of relevant documents for a given query. Since our queries have a variable number of relevant documents (1-3), this is a convenient metric.
- [Precision@1 (P@1)](https://ir-measur.es/en/latest/measures.html#p): Reflects how many of the documents retrieved in the first position are actually golden documents. Note that P@3 would be a misleading metric: since not all queries have 3 relevant documents, not even the golden dataset would score 100%.
- [Recall@3 (R@3)](https://ir-measur.es/en/latest/measures.html#r): Reflects how many of the golden documents were retrieved by the system. Note that R@1 would be a misleading metric: since a query can have multiple equally-relevant documents, not even the golden dataset would score 100%.
- [Mean Reciprocal Rank (MRR)](https://ir-measur.es/en/latest/measures.html#rr): For each query, takes the first golden document and looks up its rank in the retrieved documents. For instance, if the first golden document is retrieved second, the score for this query is 1/2. Note this metric is somewhat incomplete for our benchmark, because we might have multiple relevant documents.

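For intuition, the per-query versions of R-Precision and reciprocal rank can be sketched in a few lines (a toy illustration, not the ir-measures implementation):

```python
def r_precision(golden: set, retrieved: list) -> float:
    """Precision at R, where R = number of golden documents for the query."""
    r = len(golden)
    return len(set(retrieved[:r]) & golden) / r

def reciprocal_rank(golden: set, retrieved: list) -> float:
    """1/rank of the first golden document (0 if none was retrieved)."""
    for rank, doc in enumerate(retrieved, start=1):
        if doc in golden:
            return 1 / rank
    return 0.0

golden = {"auto_factory.py", "doc.py"}
retrieved = ["modeling_utils.py", "auto_factory.py", "doc.py"]
print(r_precision(golden, retrieved))     # 0.5: one of the top-2 is golden
print(reciprocal_rank(golden, retrieved)) # 0.5: first golden doc at rank 2
```

The aggregate benchmark numbers are simply these per-query scores averaged over all 1,000 questions.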
## Embeddings
:classical_building: **Verdict**: Use OpenAI's `text-embedding-3-small` embeddings.

Today, most retrieval systems are *dense*. They pre-compute document *embeddings* and store them in an index. At inference time, queries are also mapped to the same embedding space. In this world, retrieval is equivalent to finding the nearest neighbors of the query embedding in the index.
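As a minimal sketch of that nearest-neighbor step (with toy 2-dimensional vectors standing in for real embeddings, which have 1,536+ dimensions):

```python
import math

def cosine_similarity(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm

def nearest_neighbors(query_emb, index, top_k=2):
    """index: dict mapping doc id -> precomputed embedding vector."""
    scored = sorted(index.items(),
                    key=lambda kv: cosine_similarity(query_emb, kv[1]),
                    reverse=True)
    return [doc_id for doc_id, _ in scored[:top_k]]

index = {"a.py": [1.0, 0.0], "b.py": [0.7, 0.7], "c.py": [0.0, 1.0]}
print(nearest_neighbors([0.9, 0.1], index))  # ['a.py', 'b.py']
```

Vector stores like Pinecone perform this search with approximate-nearest-neighbor structures rather than a full scan, but the ranking criterion is the same.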
To this end, the [MTEB leaderboard](https://huggingface.co/spaces/mteb/leaderboard) (Massive Text Embedding Benchmark) offers a comprehensive comparison of open-source embeddings.

To complement this, we compared proprietary embedding APIs from [OpenAI](https://platform.openai.com/docs/guides/embeddings) and [Voyage](https://docs.voyageai.com/docs/embeddings). The main advantage of using these providers (in addition to quality) is that they provide *batch* embedding APIs, so you can get an entire repository indexed relatively quickly without the headache of hosting your own embedding models (you can do so with a simple `sage-index $GITHUB_REPO` command).

![Embeddings experiment](assets/embeddings.png)

The plot above shows the performance of the three types of embeddings from OpenAI (`text-embedding-3-small`, `text-embedding-3-large`, `text-embedding-ada-002`) and the code-specific embeddings from Voyage (`voyage-code-2`).

#### Experiment settings

- File chunks of <= 800 tokens;
- Dense retriever (nearest neighbor according to cosine distance of embeddings);
- Retrieved `top_k=25` documents;
- Reranked documents using the [NVIDIA re-ranker](https://docs.nvidia.com/nim/nemo-retriever/text-reranking/latest/using-reranking.html) and selected `top_k=3`.

#### Results

- Across most evaluation metrics, OpenAI's `text-embedding-3-small` performs best.
- It's remarkable that the `text-embedding-3-large` embeddings don't perform better, despite being double the size (3072 vs. 1536 dimensions).
- The older `text-embedding-ada-002` embeddings trail last with a huge gap in performance, so this is your call to update your pipeline if you haven't already.
## Rerankers
:classical_building: **Verdict**: Use NVIDIA's reranker.

In a world with infinitely fast compute, we would perform retrieval by passing each `<query, document>` pair through a Transformer, allowing all the query tokens to attend to all the document tokens. However, this is prohibitively expensive.

In practice, all documents are embedded independently and stored in a vector database. Most retrieval systems are two-staged: (1) embed the query independently to find its top N nearest-neighbor documents, and (2) re-encode all top N `<query, document>` pairs and select the top K scoring ones. The second stage is called *reranking*.
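The two stages can be sketched as follows; `embed` and `rerank_score` are hypothetical stand-ins for an embedding model and a cross-encoder reranker, not sage's actual interfaces:

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def two_stage_retrieve(query, index, embed, rerank_score, n=25, k=3):
    """index maps doc_id -> precomputed document embedding."""
    q = embed(query)
    # Stage 1: cheap nearest-neighbor search over the whole index.
    candidates = sorted(index, key=lambda d: dot(q, index[d]), reverse=True)[:n]
    # Stage 2: expensive joint scoring of <query, document> pairs, only on the top N.
    return sorted(candidates, key=lambda d: rerank_score(query, d), reverse=True)[:k]
```

The point of the split is cost: stage 1 touches every document but only via a dot product, while the quadratic-attention stage 2 only ever sees N candidates.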

![Rerankers experiment](assets/rerankers.png)

While the [MTEB leaderboard](https://huggingface.co/spaces/mteb/leaderboard) compares *open-source* embedding models based on their ability to rerank documents, we conducted experiments on the most popular *proprietary* APIs for reranking, including [NVIDIA](https://docs.nvidia.com/nim/nemo-retriever/text-reranking/latest/using-reranking.html), [Voyage](https://docs.voyageai.com/docs/reranker), [Cohere](https://cohere.com/rerank) and [Jina](https://jina.ai/reranker/).

#### Experiment settings
- File chunks of <= 800 tokens;
- Dense retriever using OpenAI's `text-embedding-3-small` model;
- Retrieved `top_k=25` documents;
- Reranked documents and selected `top_k=3`.

#### Results
- Across all evaluation metrics, the highest-performing rerankers are, in this order: [NVIDIA](https://docs.nvidia.com/nim/nemo-retriever/text-reranking/latest/using-reranking.html), [Voyage](https://docs.voyageai.com/docs/reranker), [Cohere](https://cohere.com/rerank) and [Jina](https://jina.ai/reranker/).
- Not using a reranker at all completely tanks performance.

## Retrieval: Sparse vs Dense
:classical_building: **Verdict**: Use fully dense retrieval.

So far, we've been experimenting with purely *dense* retrieval. That is, documents are selected solely based on the cosine distance between their embedding and the query embedding.

Before the emergence of deep learning, retrievers used to be *sparse*. Such retrievers (e.g. [TF-IDF](https://en.wikipedia.org/wiki/Tf%E2%80%93idf) or [BM25](https://en.wikipedia.org/wiki/Okapi_BM25)) are based on vectors of word counts: the vector of a document has the length of the vocabulary, with each entry counting how many times a token occurs in the document. The term *sparse* comes from the fact that most entries are 0.

Since sparse retrievers rely on exact string matching, one might assume they come in handy when the query contains a relatively unique token (e.g. a class name) that occurs in a small number of documents.
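For illustration, a minimal (unoptimized) Okapi BM25 scorer over a pre-tokenized corpus might look like this; real implementations precompute document frequencies and the average document length:

```python
import math

def bm25_score(query_tokens, doc_tokens, corpus, k1=1.5, b=0.75):
    """Okapi BM25 score of one document against a tokenized query."""
    n_docs = len(corpus)
    avgdl = sum(len(d) for d in corpus) / n_docs
    score = 0.0
    for term in query_tokens:
        freq = doc_tokens.count(term)
        if freq == 0:
            continue  # exact string match required: absent terms contribute nothing
        df = sum(1 for d in corpus if term in d)  # document frequency, for IDF
        idf = math.log((n_docs - df + 0.5) / (df + 0.5) + 1)
        score += idf * freq * (k1 + 1) / (freq + k1 * (1 - b + b * len(doc_tokens) / avgdl))
    return score

corpus = [["def", "auto_class_update", "cls"], ["readme", "installation", "usage"]]
query = ["auto_class_update"]
print(bm25_score(query, corpus[0], corpus) > bm25_score(query, corpus[1], corpus))  # True
```

Note how the rare identifier dominates the score of the file that actually contains it, which is exactly the behavior the intuition above describes.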

At the intersection of dense and sparse retrievers, *hybrid* retrievers score documents by the weighted average of the dense and sparse scores.
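The weighted average is a one-liner; we assume the `--retrieval-alpha` flag added in this PR controls a weight with semantics like `alpha` below (an assumption about the flag, not its documented contract):

```python
def hybrid_score(dense_score, sparse_score, alpha=0.5):
    """Weighted average of dense and sparse retrieval scores.
    alpha=1.0 -> purely dense; alpha=0.0 -> purely sparse."""
    return alpha * dense_score + (1 - alpha) * sparse_score

print(hybrid_score(0.8, 0.2, alpha=0.5))  # 0.5
print(hybrid_score(0.8, 0.2, alpha=1.0))  # 0.8
```

In practice the dense and sparse scores live on different scales, so they are typically normalized before being averaged.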

![Retrievers experiment](assets/retrievers.png)

In the experiment above, we compared the three types of retrievers (dense, hybrid and sparse).

#### Experiment settings
- File chunks of <= 800 tokens;
- For the dense and hybrid retrievers, we used OpenAI's `text-embedding-3-small` model for embeddings;
- Retrieved `top_k=25` documents;
- Reranked documents using the [NVIDIA re-ranker](https://docs.nvidia.com/nim/nemo-retriever/text-reranking/latest/using-reranking.html) and selected `top_k=3`.

#### Results
Somewhat surprisingly, sparse retrieval actively hurts performance. The reason is that exact string matching favors files written in natural language (which therefore match the token distribution of the query).

The plot below shows what percentage of the retrieved files are in Markdown. The purely sparse retriever chooses a Markdown file 40% of the time! Remember that we designed our questions so that the required context consists of Python files. This doesn't preclude Markdown files from actually being helpful in answering some of the questions, but surely not to this degree.

![Percentage of Markdown files retrieved](assets/markdown.png)

## Chunk sizes
:classical_building: **Verdict**: 800 tokens per chunk works well.

The [CodeRag paper](https://arxiv.org/pdf/2406.14497) suggests that the ideal chunk size is somewhere between 200-800 tokens. All our experiments above used 800 tokens per chunk. When experimenting with the other end of the spectrum, we saw very mild improvements from smaller chunks. We believe these marginal gains are not worth the increased indexing time (since we need to send 4x more requests to the batch embedding APIs).

![Chunks experiment](assets/chunks.png)
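A simplified illustration of splitting a token sequence into fixed-size chunks (the repo's actual chunker in `sage/chunker.py` is more involved; this only shows where the 4x factor comes from when going from 800 to 200 tokens):

```python
def chunk_tokens(tokens, max_tokens=800):
    """Split a token sequence into consecutive chunks of at most max_tokens."""
    return [tokens[i:i + max_tokens] for i in range(0, len(tokens), max_tokens)]

tokens = list(range(2000))  # stand-in for a tokenized source file
print([len(c) for c in chunk_tokens(tokens, max_tokens=800)])  # [800, 800, 400]
print(len(chunk_tokens(tokens, max_tokens=200)))               # 10 chunks, ~4x more requests
```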
benchmarks/retrieval/assets/chunks.png
ADDED
|
benchmarks/retrieval/assets/embeddings.png
ADDED
|
benchmarks/retrieval/assets/markdown.png
ADDED
|
benchmarks/retrieval/assets/rerankers.png
ADDED
|
benchmarks/retrieval/assets/retrievers.png
ADDED
|
benchmarks/retrieval/retrieve.py
ADDED
|
@@ -0,0 +1,108 @@
```python
"""Script to call retrieval on a benchmark dataset.

Make sure to `pip install ir_measures` before running this script.
"""

import json
import logging
import os
import time

import configargparse
from ir_measures import MAP, MRR, P, Qrel, R, Rprec, ScoredDoc, calc_aggregate, nDCG

import sage.config
from sage.retriever import build_retriever_from_args

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()
logger.setLevel(logging.INFO)


def main():
    parser = configargparse.ArgParser(
        description="Runs retrieval on a benchmark dataset.", ignore_unknown_config_file_keys=True
    )
    parser.add("--benchmark", required=True, help="Path to the benchmark dataset.")
    parser.add(
        "--gold-field", default="context_files", help="Field in the benchmark dataset that contains the golden answers."
    )
    parser.add(
        "--question-field", default="question", help="Field in the benchmark dataset that contains the questions."
    )
    parser.add(
        "--logs-dir",
        default=None,
        help="Path where to output predictions and metrics. Optional, since metrics are also printed to console.",
    )
    parser.add("--max-instances", default=None, type=int, help="Maximum number of instances to process.")

    sage.config.add_config_args(parser)
    sage.config.add_embedding_args(parser)
    sage.config.add_vector_store_args(parser)
    sage.config.add_reranking_args(parser)
    args = parser.parse_args()
    sage.config.validate_vector_store_args(args)

    retriever = build_retriever_from_args(args)

    with open(args.benchmark, "r") as f:
        benchmark = json.load(f)
    if args.max_instances is not None:
        benchmark = benchmark[: args.max_instances]

    golden_docs = []  # List of ir_measures.Qrel objects.
    retrieved_docs = []  # List of ir_measures.ScoredDoc objects.

    for question_idx, item in enumerate(benchmark):
        print(f"Processing question {question_idx}...")

        query_id = str(question_idx)  # Solely needed for the ir_measures library.

        for golden_filepath in item[args.gold_field]:
            # All the file paths in the golden answer are equally relevant for the query (i.e. the order is
            # irrelevant), so we set relevance=1 for all of them.
            golden_docs.append(Qrel(query_id=query_id, doc_id=golden_filepath, relevance=1))

        # Make a retrieval call for the current question.
        retrieved = retriever.invoke(item[args.question_field])
        item["retrieved"] = []
        for doc_idx, doc in enumerate(retrieved):
            # The absolute value of the scores below does not affect the metrics; it merely determines the ranking
            # of the retrieved documents. The key of the score varies depending on the underlying retriever. If
            # there's no score, we use 1/(doc_idx+1) since it preserves the order of the documents.
            score = doc.metadata.get("score", doc.metadata.get("relevance_score", 1 / (doc_idx + 1)))
            retrieved_docs.append(ScoredDoc(query_id=query_id, doc_id=doc.metadata["file_path"], score=score))
            # Update the output dictionary with the retrieved documents.
            item["retrieved"].append({"file_path": doc.metadata["file_path"], "score": score})

        if "answer" in item:
            item.pop("answer")  # Makes the output file harder to read.

    print("Calculating metrics...")
    results = calc_aggregate([Rprec, P @ 1, R @ 3, nDCG @ 3, MAP, MRR], golden_docs, retrieved_docs)
    results = {str(key): value for key, value in results.items()}

    if args.logs_dir:
        if not os.path.exists(args.logs_dir):
            os.makedirs(args.logs_dir)

        out_data = {
            "data": benchmark,
            "metrics": results,
            "flags": vars(args),  # For reproducibility.
        }

        output_file = os.path.join(args.logs_dir, f"{time.time()}.json")
        with open(output_file, "w") as f:
            json.dump(out_data, f, indent=4)
        print(f"Predictions and metrics saved to {output_file}")

    for key in sorted(results.keys()):
        print(f"{key}: {results[key]}")


if __name__ == "__main__":
    main()
```
|
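The score fallback in the retrieval loop above can be tried in isolation: `1/(doc_idx+1)` is strictly decreasing, so documents without a native score keep their retriever-assigned order (`metadata_list` below is a made-up example of mixed retriever metadata):

```python
metadata_list = [{"score": 0.9}, {}, {"relevance_score": 0.4}, {}]
scores = [
    md.get("score", md.get("relevance_score", 1 / (i + 1)))
    for i, md in enumerate(metadata_list)
]
print(scores)  # [0.9, 0.5, 0.4, 0.25]
```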
benchmarks/retrieval/sample.json
ADDED
|
@@ -0,0 +1,177 @@
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"repo": "huggingface/transformers",
|
| 4 |
+
"commit": "7bb1c99800d235791dace10305731f377db8077b",
|
| 5 |
+
"context_files": [
|
| 6 |
+
"huggingface/transformers/src/transformers/commands/serving.py",
|
| 7 |
+
"huggingface/transformers/src/transformers/pipelines/__init__.py"
|
| 8 |
+
],
|
| 9 |
+
"question": "With the introduction of a new translation service for \"en_to_es\", how does `serve_command_factory` ensure the server is prepared to handle this specific task efficiently?",
|
| 10 |
+
"answer": "The `serve_command_factory` function is designed to dynamically configure and deploy a server that can handle a variety of tasks, including complex and parameterized tasks such as language-specific translations. When a new translation service for \"en_to_es\" is introduced, the function ensures efficient handling through several key steps and mechanisms:\n\n1. **Task Specification and Validation**: The function begins by accepting command line arguments (`args`) which specify the task among other parameters. The task, in this case \"translation_en_to_es\", is passed to the `pipeline` function. This function internally calls `check_task` to validate and process the task string. The `check_task` function checks if the task is recognized and supported, and it returns structured data that includes default configurations and any task-specific options necessary for the task.\n\n2. **Dynamic Pipeline Configuration**: The `pipeline` function uses the information returned by `check_task` to dynamically configure a pipeline suitable for the specified task. This includes loading the appropriate model, tokenizer, and any additional components like feature extractors or image processors that are required for the task. For translation tasks, this might involve loading a specific model that is optimized for the language pair \"en_to_es\".\n\n3. **Model and Tokenizer Loading**: Depending on the task and the provided arguments, the `pipeline` function loads the necessary model and tokenizer. If the model or tokenizer is not explicitly specified in the command line arguments, the function will load default models and tokenizers that are appropriate for the task. This ensures that the server is using the most suitable tools for the translation between English and Spanish.\n\n4. 
**Device and Performance Configuration**: The `pipeline` function also handles device allocation (e.g., CPU, GPU) based on the provided arguments, ensuring that the model runs on the optimal hardware. This is crucial for performance, especially when handling complex tasks like real-time translation.\n\n5. **Server Initialization**: Once the pipeline is configured, `serve_command_factory` initializes a `ServeCommand` with the configured pipeline and server parameters such as host, port, and number of workers. This command encapsulates all the necessary settings to run the server efficiently.\n\n6. **Adaptability and Scalability**: The design of `serve_command_factory` and the underlying `pipeline` function allows for easy adaptation to new tasks and changes in task requirements. Adding support for a new translation like \"en_to_es\" primarily involves ensuring that the `check_task` function recognizes the task and that appropriate models and tokenizers are available and correctly configured.\n\nIn summary, `serve_command_factory` ensures that the server is prepared to handle the \"en_to_es\" translation task efficiently by dynamically configuring the server based on the specified task, validating and processing the task to load appropriate models and tools, and setting up the server with optimal hardware and network settings. This process guarantees that the server can adapt to new and complex tasks with minimal manual intervention.",
|
| 11 |
+
"expected_identifiers": [
|
| 12 |
+
"args",
|
| 13 |
+
"check_task",
|
| 14 |
+
"pipeline",
|
| 15 |
+
"ServeCommand"
|
| 16 |
+
]
|
| 17 |
+
},
|
| 18 |
+
{
|
| 19 |
+
"repo": "huggingface/transformers",
|
| 20 |
+
"commit": "7bb1c99800d235791dace10305731f377db8077b",
|
| 21 |
+
"context_files": [
|
| 22 |
+
"huggingface/transformers/src/transformers/models/esm/modeling_esmfold.py",
|
| 23 |
+
"huggingface/transformers/src/transformers/models/esm/openfold_utils/feats.py",
|
| 24 |
+
"huggingface/transformers/src/transformers/models/esm/openfold_utils/tensor_utils.py"
|
| 25 |
+
],
|
| 26 |
+
"question": "In a high-throughput setting where multiple protein structures are processed simultaneously, how does `EsmForProteinFolding.output_to_pdb` ensure accurate and independent structural representation in the resulting PDB files?",
|
| 27 |
+
"answer": "In a high-throughput setting where multiple protein structures are processed simultaneously, the function `output_to_pdb` ensures accurate and independent structural representation in the resulting PDB files through a combination of specialized tensor operations and careful indexing. This is achieved primarily through the use of the `atom14_to_atom37` function, which itself relies on the `batched_gather` function to correctly map atom positions from a simplified model output to a more detailed atomic representation.\n\n### Detailed Workflow:\n\n1. **Batch Processing and Tensor Operations**:\n - The `output_to_pdb` function begins by converting all tensor data to the CPU and converting them to NumPy arrays for easier manipulation. This step is crucial for performance and compatibility with subsequent operations that may not be optimized for GPU tensors.\n\n2. **Mapping Atom Positions**:\n - The function `atom14_to_atom37` is called within `output_to_pdb`. This function is responsible for expanding the reduced atom representation (14 atoms per amino acid) to a fuller representation (37 atoms per amino acid). It uses the `batched_gather` function to achieve this mapping accurately across potentially multiple proteins in a batch.\n\n3. **Complex Indexing with `batched_gather`**:\n - `batched_gather` plays a critical role in ensuring that the atom positions are mapped correctly. It constructs a complex indexing tuple that combines batch indices with the provided indices for gathering (`inds`). This tuple (`ranges`) includes both batch dimensions and the specific indices where atoms need to be gathered from the `atom14` tensor.\n - The use of `ranges` in `batched_gather` ensures that each protein's data is handled independently, preventing any cross-contamination or mixing of data between different proteins in the batch. This is crucial for maintaining the structural integrity of each protein.\n\n4. 
**Application of Mask and Final Adjustments**:\n - After mapping the positions, `atom14_to_atom37` applies a mask (`batch[\"atom37_atom_exists\"]`) to ensure that only existing atoms are considered. This step further ensures the accuracy of the structural data by zeroing out positions of non-existent atoms, preventing any erroneous data from affecting the structural representation.\n\n5. **Generation of PDB Data**:\n - Back in `output_to_pdb`, for each protein in the batch, an instance of `OFProtein` is created with the mapped atom positions, types, and other relevant data. The `to_pdb` function is then used to convert these protein data into the PDB format, ready for downstream applications like molecular dynamics simulations.\n\n### Conclusion:\n\nThrough the careful use of tensor operations, complex indexing, and data masking, `output_to_pdb` ensures that each protein's structural data is accurately and independently represented in the PDB outputs. This methodical approach is essential in high-throughput settings, where the accuracy and integrity of structural data are paramount for subsequent scientific analysis and applications.",
"expected_identifiers": [
"atom14_to_atom37",
"batched_gather",
"batch[\"atom37_atom_exists\"]",
"OFProtein"
]
},
{
"repo": "huggingface/transformers",
"commit": "7bb1c99800d235791dace10305731f377db8077b",
"context_files": [
"huggingface/transformers/src/transformers/models/auto/auto_factory.py",
"huggingface/transformers/src/transformers/dynamic_module_utils.py"
],
"question": "Following a security update in the production environment that limits internet connectivity, how does `_BaseAutoModelClass.from_pretrained` guarantee that the loaded model adheres strictly to the predefined version and settings?",
"answer": "In the updated production environment with restricted internet connectivity, `_BaseAutoModelClass.from_pretrained` ensures that the model loaded adheres strictly to the predefined version and settings through several key mechanisms, primarily involving the management of model files and code via a version control system and secure access to private repositories.\n\n### Version Control and Revision Specification\n\nThe function leverages a version control system that allows users to specify exact revisions of the model or code they wish to use. This is evident in the handling of the `revision` parameter in functions like `get_cached_module_file` and `get_class_from_dynamic_module`. The `revision` parameter can accept any identifier allowed by git, such as a branch name, a tag name, or a commit id. This ensures that the exact version of the model or code that was tested and approved in other environments (like development or staging) is the same version being deployed in production.\n\nFor example, in the `get_cached_module_file` function, the `revision` parameter is used to fetch the specific version of a module file from a repository:\n```python\nresolved_module_file = cached_file(\n pretrained_model_name_or_path,\n module_file,\n cache_dir=cache_dir,\n force_download=force_download,\n proxies=proxies,\n resume_download=resume_download,\n local_files_only=local_files_only,\n token=token,\n revision=revision,\n repo_type=repo_type,\n _commit_hash=_commit_hash,\n)\n```\n\n### Secure Access to Private Repositories\n\nThe function can authenticate access to private repositories using tokens, which is crucial when operating in environments with strict security protocols. The `token` parameter, which can be set to a string or `True` (to use the token generated by `huggingface-cli login`), is used to authenticate HTTP requests for remote files. 
This is handled securely in both `get_cached_module_file` and `get_class_from_dynamic_module`, ensuring that only authorized users can access private model files or code.\n\nFor instance, in `get_class_from_dynamic_module`, the `token` parameter is used to authenticate and download the necessary module file:\n```python\nfinal_module = get_cached_module_file(\n repo_id,\n module_file + \".py\",\n cache_dir=cache_dir,\n force_download=force_download,\n resume_download=resume_download,\n proxies=proxies,\n token=token,\n revision=code_revision,\n local_files_only=local_files_only,\n repo_type=repo_type,\n)\n```\n\n### Handling Restricted Internet Connectivity\n\nIn environments with limited internet access, the `local_files_only` parameter becomes particularly important. This parameter, when set to `True`, forces the function to only look for model files locally and not attempt to download them from the internet. This is crucial for ensuring that the model loading process does not fail due to lack of internet access and adheres to strict security protocols that might block external internet connections.\n\n### Conclusion\n\nBy utilizing these mechanisms, `_BaseAutoModelClass.from_pretrained` ensures that the model loaded in a production environment with restricted internet access is exactly the version specified, using secure and authenticated access where necessary. This approach guarantees consistency, reproducibility, and adherence to security protocols across different environments.",
"expected_identifiers": [
"revision",
"token",
"local_files_only"
]
},
{
"repo": "huggingface/transformers",
"commit": "7bb1c99800d235791dace10305731f377db8077b",
"context_files": [
"huggingface/transformers/src/transformers/models/auto/auto_factory.py",
"huggingface/transformers/src/transformers/utils/doc.py"
],
"question": "When developing a specialized model class in the Transformers library, how does `auto_class_update` ensure that the new class's methods are tailored specifically for its requirements while preserving the functionality of the original methods from the base class?",
"answer": "In the Transformers library, the `auto_class_update` function plays a crucial role in dynamically creating specialized model classes that inherit functionalities from a base class but also have unique customizations. This is particularly important when different model classes need specific configurations or preprocessing steps that are not shared across all models.\n\nThe core mechanism that allows `auto_class_update` to achieve this functionality without altering the behavior of the base class methods lies in its use of the `copy_func` function. Here's how it works step-by-step:\n\n1. **Copying the Function**: `copy_func` is used to create an exact copy of the methods `from_config` and `from_pretrained` from the base class `_BaseAutoModelClass`. This is done by duplicating the `__code__` object of these methods. The `__code__` object contains the compiled executable code that the Python interpreter runs. By copying this code object, the new function retains the exact behavior and logic of the original function.\n\n2. **Customization of the Copied Function**: After copying, `auto_class_update` modifies the docstrings of these methods to tailor them to the specific subclass. This involves inserting a specific `head_doc`, replacing placeholders like `\"BaseAutoModelClass\"` with the subclass's name, and updating example checkpoints specific to the model type (e.g., `\"google-bert/bert-base-cased\"`). These modifications are crucial for providing accurate and relevant documentation and guidance specific to each subclass.\n\n3. **Re-assignment as Class Methods**: Once the functions are copied and customized, they are re-assigned to the subclass as class methods. This is done using `classmethod(from_config)` and `classmethod(from_pretrained)`. This step ensures that these methods, now tailored and documented specifically for the subclass, are callable on the subclass itself.\n\n4. 
**Preservation of Base Class Functionality**: Since the original methods are copied before being modified, the base class `_BaseAutoModelClass` retains its original `from_config` and `from_pretrained` methods without any changes. This isolation ensures that modifications specific to one subclass do not impact the behavior or documentation of these methods in the base class or any other subclasses.\n\nBy following this process, `auto_class_update` ensures that each subclass in the Transformers library can have methods that are specifically tailored to its requirements, both in terms of functionality and documentation, while preserving the integrity and functionality of the original methods from the base class. This approach enhances modularity and flexibility in the library, allowing developers to easily extend and customize model classes for various use cases.",
"expected_identifiers": [
"__code__",
"copy_func",
"from_config",
"from_pretrained"
]
},
{
"repo": "huggingface/transformers",
"commit": "7bb1c99800d235791dace10305731f377db8077b",
"context_files": [
"huggingface/transformers/src/transformers/models/megatron_gpt2/checkpoint_reshaping_and_interoperability.py",
"huggingface/transformers/src/transformers/modeling_utils.py"
],
"question": "Given a system limitation of 5GB per file, how does `convert_checkpoint_from_megatron_to_transformers` manage the storage of a large model's data to comply with this restriction?",
"answer": "The `convert_checkpoint_from_megatron_to_transformers` function manages the storage of a large model's data to comply with a system limitation of 5GB per file by utilizing the `shard_checkpoint` function to split the model's state dictionary into multiple sub-checkpoints, each of which does not exceed the specified maximum size.\n\nHere's a detailed breakdown of how this is achieved:\n\n1. **Sharding Process**: The `shard_checkpoint` function is called within `convert_checkpoint_from_megatron_to_transformers` to handle the division of the model's weights into smaller parts or shards. This function takes the entire state dictionary of the model (`output_state_dict`) and a maximum shard size as inputs.\n\n2. **Size Calculation**: The function calculates the byte size of each tensor in the state dictionary using the `dtype_byte_size` function. This function determines the number of bytes each element of a tensor occupies in memory, based on the tensor's data type (`dtype`). This calculation is crucial as it helps in accurately assessing how much space each tensor will take when saved as part of a shard.\n\n3. **Iterative Sharding**: The `shard_checkpoint` iterates through each tensor in the state dictionary and adds them to the current shard until adding another tensor would exceed the maximum shard size (5GB in this scenario). When this limit is reached, a new shard is started. This ensures that no individual shard file exceeds the specified size limit.\n\n4. **Handling Oversized Tensors**: If a single tensor is larger than the maximum shard size, it is placed in its own shard. This is a necessary exception to prevent the function from failing due to an inability to split a tensor.\n\n5. **Saving Shards**: Each shard is saved as a separate file. The naming convention and indexing ensure that each part of the model can be identified and accessed correctly. 
The function also generates an index file if the model is split into multiple shards, detailing where each parameter is stored.\n\n6. **Parameter Mapping**: The function maintains a mapping (`weight_map`) of model parameters to their respective shard files. This mapping is crucial for efficiently loading the model from its sharded state.\n\nBy following these steps, the `convert_checkpoint_from_megatron_to_transformers` function ensures that each shard of the converted model adheres to the 5GB file size limit imposed by the system. This methodical sharding allows for efficient storage and handling of large models without exceeding system file size limitations.",
"expected_identifiers": [
"shard_checkpoint",
"dtype_byte_size",
"output_state_dict",
"weight_map"
]
},
{
"repo": "huggingface/transformers",
"commit": "7bb1c99800d235791dace10305731f377db8077b",
"context_files": [
"huggingface/transformers/src/transformers/quantizers/quantizer_hqq.py",
"huggingface/transformers/src/transformers/integrations/hqq.py"
],
"question": "In a scenario where a neural network model is being optimized for deployment, how does `HqqHfQuantizer._process_model_before_weight_loading` ensure that each linear module is appropriately and uniquely quantized?",
"answer": "In the scenario where a neural network model is being optimized for deployment using the `HqqHfQuantizer._process_model_before_weight_loading` function, the process of ensuring that each linear module is appropriately and uniquely quantized involves several key steps and functions.\n\n1. **Tagging Modules with Unique Identifiers**: The process begins with the `get_linear_tags` function, which is responsible for identifying and tagging all linear modules within the model. This function uses a `set` to collect the names of these modules, which inherently ensures that each tag is unique (since sets do not allow duplicates). This is crucial because it prevents any confusion or errors in later stages when quantization parameters are applied to these tags.\n\n2. **Applying Quantization Configuration**: Once the linear modules are tagged, the `prepare_for_hqq_linear` function takes over. This function receives a `quantization_config` and a list of modules not to convert. It first calls `autoname_modules` to ensure each module in the model has a unique name, and then retrieves the linear tags using `get_linear_tags`. The function then filters these tags to exclude any specified in `skip_modules` or `modules_to_not_convert`, ensuring that the quantization process is applied only to the relevant modules.\n\n3. **Mapping Quantization Parameters**: The core of the quantization process happens when `prepare_for_hqq_linear` maps the quantization parameters to each linear tag. This is done by creating a dictionary (`patch_params`) where each key is a linear tag and the value is the corresponding quantization parameter. If specific quantization parameters are not provided for a tag, a default configuration is applied. This mapping ensures that each linear module (identified uniquely by its tag) receives a tailored set of quantization parameters.\n\n4. 
**Updating Model Configuration**: After mapping the quantization parameters, the `prepare_for_hqq_linear` function updates the model's configuration to include these parameters, ensuring that each linear module's configuration reflects its unique quantization settings. This step is crucial for the actual quantization process, where linear modules might be replaced with their quantized counterparts (`HQQLinear`), depending on the configuration.\n\n5. **Final Verification and Logging**: The function checks if any linear modules have been replaced and logs a warning if no modules were found for quantization. This serves as a final check to ensure that the quantization process has been applied as expected.\n\nIn summary, the `HqqHfQuantizer._process_model_before_weight_loading` function ensures that each linear module is uniquely and appropriately quantized by meticulously tagging each module, applying a tailored quantization configuration, and updating the model to reflect these settings. This process is designed to optimize the model's performance for deployment, ensuring that each module operates efficiently and accurately under the constraints of quantization.",
"expected_identifiers": [
"get_linear_tags",
"autoname_modules",
"prepare_for_hqq_linear",
"patch_params"
]
},
{
"repo": "huggingface/transformers",
"commit": "7bb1c99800d235791dace10305731f377db8077b",
"context_files": [
"huggingface/transformers/src/transformers/models/esm/modeling_esmfold.py",
"huggingface/transformers/src/transformers/models/esm/openfold_utils/loss.py"
],
"question": "When analyzing a protein sequence with low complexity using `EsmForProteinFolding.forward`, how is the stability and definition of the output ensured?",
"answer": "When analyzing a protein sequence with low complexity using the `EsmForProteinFolding.forward` function, the stability and definition of the output are ensured through several key mechanisms embedded within the function's implementation, particularly in how it handles normalization and potential numerical instabilities.\n\n1. **Normalization of Residue Weights**: In the `compute_tm` function, residue weights are normalized by their sum, with the addition of a small constant `eps` (epsilon) to prevent division by zero. This is crucial when dealing with sequences of low complexity where certain residues might be overrepresented or underrepresented. The normalization step is represented in the code as:\n ```python\n normed_residue_mask = residue_weights / (eps + residue_weights.sum())\n ```\n Here, `eps` acts as a safeguard against division by zero, ensuring that the function remains numerically stable and produces defined outputs even when the sum of residue weights is extremely small or zero.\n\n2. **Weighted Average Calculation**: The function calculates a weighted average of the Template Modeling (TM) scores across different bins, which is critical for obtaining a reliable TM score. This is done using the normalized residue weights, ensuring that each residue's contribution is proportionate to its presence, thus maintaining accuracy and stability in the final score calculation:\n ```python\n per_alignment = torch.sum(predicted_tm_term * normed_residue_mask, dim=-1)\n ```\n This step aggregates the TM scores across all residues, factoring in their normalized weights, which is particularly important in low complexity sequences where certain residues might dominate.\n\n3. **Handling of Edge Cases**: The use of `eps` in the normalization process is a direct method to handle edge cases, such as sequences with low complexity or unusual amino acid distributions. 
By ensuring that the denominator in the normalization step is never zero, the function avoids potential runtime errors (like NaN or infinite values), which could disrupt the analysis process.\n\n4. **Integration within `EsmForProteinFolding.forward`**: The stability and definition of outputs from the `EsmForProteinFolding.forward` function are further supported by how `compute_tm` integrates with other components of the model. The TM scores computed are used alongside other structural predictions, contributing to a comprehensive evaluation of the predicted protein structures. This integration ensures that the outputs are not only stable and defined but also meaningful in the context of protein structure prediction.\n\nIn summary, the `EsmForProteinFolding.forward` function ensures stable and defined outputs for protein structure predictions, particularly in scenarios involving low complexity sequences, by employing robust normalization techniques and handling potential numerical instabilities through the careful addition of a small epsilon value in critical calculations. This approach guarantees that the function can reliably process a wide range of input data without encountering computational errors.",
"expected_identifiers": [
"normed_residue_mask",
"eps",
"residue_weights / (eps + residue_weights.sum())",
"torch.sum(predicted_tm_term * normed_residue_mask, dim=-1)"
]
},
{
"repo": "huggingface/transformers",
"commit": "7bb1c99800d235791dace10305731f377db8077b",
"context_files": [
"huggingface/transformers/src/transformers/pipelines/question_answering.py",
"huggingface/transformers/src/transformers/data/processors/squad.py"
],
"question": "In a scenario where the textual data includes unusually lengthy paragraphs, how does `QuestionAnsweringPipeline.preprocess` ensure comprehensive coverage of all context tokens in the model's input sequences?",
"answer": "In scenarios where the textual data includes unusually lengthy paragraphs that exceed the model's maximum input length, the `QuestionAnsweringPipeline.preprocess` function ensures comprehensive coverage of all context tokens in the model's input sequences through a meticulous management of tokenization and handling of overflow tokens. This process is crucial for maintaining the integrity and continuity of the context information, which is essential for the model to accurately answer questions based on the provided context.\n\n### Step-by-Step Explanation:\n\n1. **Tokenization and Pairing**:\n The function begins by tokenizing the question and context separately. Depending on the tokenizer's configuration (`tokenizer.padding_side`), the question and context are arranged in a specific order (either question first or context first). This is handled in the lines where `encoded_inputs` is defined using `self.tokenizer(text, text_pair, ...)`. \n\n2. **Handling Long Contexts with Overflow Tokens**:\n The key parameter here is `return_overflowing_tokens=True` within the tokenizer call. This setting ensures that when the combined length of the question and context exceeds `max_seq_len`, the tokenizer automatically generates additional input sequences that contain the \"overflow\" tokens from the context. These sequences overlap by a number of tokens defined by `doc_stride`, which is calculated as `min(max_seq_len // 2, 128)`.\n\n3. **Creating Overlapping Spans**:\n The overlapping spans are crucial for ensuring that tokens near the boundaries of a sequence are also seen in different contextual surroundings, enhancing the model's ability to understand and answer questions about tokens that appear near the maximum sequence length limit. This overlap is managed by the `stride` parameter in the tokenizer, which is set to `doc_stride`.\n\n4. 
**Feature Construction**:\n For each span generated from the overflowing tokens, the function constructs a feature object that includes not only the token ids (`input_ids`) but also attention masks, token type ids, and a special mask (`p_mask`) which indicates which tokens can be part of an answer. The `p_mask` is particularly important as it helps the model distinguish between context tokens (potential answer locations) and non-context tokens (like those belonging to the question or special tokens).\n\n5. **Yielding Processed Features**:\n Each feature constructed from the spans is then yielded one by one, with additional metadata such as whether it is the last feature of the example. This is handled in the loop `for i, feature in enumerate(features):` where each feature is prepared according to the model's requirements, potentially converting them into tensors suitable for the model's computation framework (PyTorch or TensorFlow).\n\n### Conclusion:\n\nBy managing the tokenization and overflow tokens effectively, `QuestionAnsweringPipeline.preprocess` ensures that every token from a lengthy context is included in at least one input sequence to the model. This comprehensive coverage is achieved through the creation of multiple, overlapping input sequences that ensure no contextual information is lost, thereby enabling the model to perform accurately even with contexts that far exceed its maximum input length. This methodical approach is crucial for handling real-world data where lengthy paragraphs are common, ensuring robust model performance across varied datasets.",
"expected_identifiers": [
"return_overflowing_tokens=True",
"doc_stride",
"p_mask"
]
},
{
"repo": "huggingface/transformers",
"commit": "7bb1c99800d235791dace10305731f377db8077b",
"context_files": [
"huggingface/transformers/examples/research_projects/movement-pruning/masked_run_squad.py",
"huggingface/transformers/src/transformers/data/processors/squad.py"
],
"question": "Given the challenge of training models on datasets with varying context lengths, how does `load_and_cache_examples` in `examples/research_projects/movement-pruning/masked_run_squad.py` ensure the production of a dataset that supports accurate answer position prediction by the models?",
"answer": "The `load_and_cache_examples` function in `examples/research_projects/movement-pruning/masked_run_squad.py` is designed to ensure that the dataset produced supports accurate answer position prediction by models, even when dealing with varying context lengths. This is achieved through a series of steps that involve careful handling and processing of the data, particularly when contexts exceed the model's maximum sequence length. Here's how the function manages this:\n\n1. **Data Loading and Caching**: Initially, the function checks if a cached version of the processed data exists. If it does, it loads the features, dataset, and examples directly from the cache, bypassing the need for reprocessing. This step ensures efficiency and consistency in data usage across different training runs.\n\n2. **Dynamic Data Processing**: If no cached data is available, the function processes the raw data to generate features suitable for model training. This involves tokenizing the text and converting the SQuAD examples into features that models can understand and learn from.\n\n3. **Handling Extended Contexts**: The core of handling varying context lengths lies in the `squad_convert_examples_to_features` function, which is called within `load_and_cache_examples`. This function uses `squad_convert_example_to_features` to process each example individually.\n\n4. **Segmentation and Token Index Adjustment**: In `squad_convert_example_to_features`, the context is potentially split into multiple spans if its length exceeds the model's maximum sequence length. This is crucial because it allows the model to handle long contexts by breaking them down into manageable parts. Each span is processed to ensure that the start and end positions of answers are correctly adjusted within the tokenized context. 
This adjustment is handled by the `_improve_answer_span` function, which ensures that the answer spans are accurately placed within the tokens, even if the context is segmented.\n\n5. **Feature Construction**: Each span is then converted into a set of features, including input IDs, attention masks, token type IDs, and the positions of the answers. Special care is taken to mark tokens that cannot be part of the answers (using a p_mask), and to identify the maximum context for each token, which is critical for understanding which part of the split context a token belongs to.\n\n6. **Dataset Compilation**: After processing, the features are compiled into a dataset format (either PyTorch or TensorFlow, based on the configuration). This dataset includes all necessary information for the model to learn from, including the context, the question, and the correct positions of the answers.\n\nBy carefully managing the tokenization, segmentation, and feature construction processes, `load_and_cache_examples` ensures that the dataset it produces allows models to accurately predict answer positions, regardless of the length of the context. This capability is essential for training robust question-answering models that can handle real-world data, where context lengths can vary significantly.",
"expected_identifiers": [
"squad_convert_examples_to_features",
"squad_convert_example_to_features",
"_improve_answer_span",
"p_mask"
]
},
{
"repo": "huggingface/transformers",
"commit": "7bb1c99800d235791dace10305731f377db8077b",
"context_files": [
"huggingface/transformers/src/transformers/modeling_flax_utils.py",
"huggingface/transformers/src/transformers/utils/hub.py"
],
"question": "In a scenario where network conditions are suboptimal, how does `FlaxPreTrainedModel.from_pretrained` manage to reduce the model loading time?",
"answer": "In scenarios where network conditions are suboptimal, the `FlaxPreTrainedModel.from_pretrained` function effectively reduces model loading time by leveraging a sophisticated caching mechanism. This mechanism is crucial for managing the download and storage of model shards, ensuring efficient and faster model initialization.\n\n### Caching Mechanism:\nThe function first checks if the required model shards are already available in the local cache before attempting any network requests. This is achieved through the `try_to_load_from_cache` function, which inspects the cache for the presence of the last shard of the model. If the last shard is found in the cache, it is likely that all previous shards are also cached, thus avoiding the need for further network requests.\n\n### Download and Cache Management:\nIf the shards are not found in the cache, `FlaxPreTrainedModel.from_pretrained` proceeds to download them. Each shard's presence is verified using the `cached_file` function, which handles the downloading and caching of the shard if it is not already present. This function also supports resuming downloads, which is particularly useful in suboptimal network conditions where downloads might be interrupted.\n\n### Efficient Shard Handling:\nThe function `get_checkpoint_shard_files` is specifically designed to manage sharded model files. It reads the checkpoint index file to determine all the necessary shards for the model and then ensures each shard is either fetched from the cache or downloaded. This process is streamlined by the use of a progress bar (managed by `tqdm`), which provides visual feedback on the download process, enhancing user experience especially in network-constrained environments.\n\n### Impact of Caching on Model Loading Time:\nBy prioritizing cached shards, `FlaxPreTrainedModel.from_pretrained` significantly reduces the dependency on network bandwidth and stability. 
This is particularly beneficial in scenarios with limited network resources, as it minimizes the time spent in downloading model components. The caching mechanism ensures that once a model shard is downloaded and stored locally, subsequent loads of the same model will utilize the cached versions, thereby bypassing the network entirely and leading to much faster model initialization times.\n\n### Conclusion:\nThe caching strategy employed by `FlaxPreTrainedModel.from_pretrained` not only optimizes the use of network resources but also ensures consistent and reduced model loading times, regardless of network conditions. This approach is instrumental in scenarios where models need to be switched frequently or reloaded, providing a seamless and efficient user experience.",
"expected_identifiers": [
"try_to_load_from_cache",
"cached_file",
"get_checkpoint_shard_files",
"tqdm"
]
},
{
"repo": "huggingface/transformers",
"commit": "7bb1c99800d235791dace10305731f377db8077b",
"context_files": [
"huggingface/transformers/examples/research_projects/information-gain-filtration/run_clm_igf.py",
"huggingface/transformers/examples/research_projects/information-gain-filtration/igf/igf.py"
],
"question": "In light of recent dataset size restrictions for training purposes, how does `generate_n_pairs` maintain compliance by ensuring the objective set adheres to the specified size and article length requirements?",
"answer": "The `generate_n_pairs` function ensures compliance with dataset size restrictions by meticulously managing the creation of the objective set through its subordinate function `generate_datasets`. This process is governed by specific parameters and conditions set within the code to meet the required criteria of size and article length.\n\n1. **Size of the Objective Set**: The function `generate_datasets` is designed to create an objective set that contains exactly the number of articles specified by the `number` parameter, which is passed from `generate_n_pairs` as `size_objective_set`. In the provided code, this value is set to 100. The loop within `generate_datasets` that populates the `objective_set` list includes a condition to break once the length of this list reaches the specified `number` (see the line `if len(objective_set) >= number: break`). This ensures that no more than 100 articles are added to the objective set, directly adhering to the dataset size restrictions.\n\n2. **Article Length Management**: The function also manages the length of each article in the objective set based on the `context_len` parameter. If `trim` is set to `True`, the function trims the articles to ensure they do not exceed the specified `context_len`. This is achieved by selecting a starting point randomly within the article and then slicing the article to obtain a segment of the specified `context_len` (see the line `objective_set.append(example[0, start : start + context_len])`). This ensures that each article in the objective set adheres to the length restrictions.\n\n3. **Compliance with Regulations**: By strictly controlling both the number of articles and their lengths as described, `generate_n_pairs` ensures that the objective set complies with new regulations requiring training datasets to contain no more than 100 articles, each of a specified maximum length. 
This compliance is crucial for ethical review and adherence to training dataset standards.\n\nIn summary, `generate_n_pairs` maintains compliance with dataset size and article length restrictions through careful implementation in `generate_datasets`, which explicitly controls the size of the objective set and trims articles to the required length based on the parameters provided. This methodical approach ensures that the objective set meets specified criteria, crucial for adhering to regulatory standards.",
|
| 170 |
+
"expected_identifiers": [
|
| 171 |
+
"generate_n_pairs",
|
| 172 |
+
"generate_datasets",
|
| 173 |
+
"size_objective_set",
|
| 174 |
+
"context_len"
|
| 175 |
+
]
|
| 176 |
+
}
|
| 177 |
+
]
|
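The sample answer above walks through two guards in `generate_datasets`: capping the objective set at `number` articles, and trimming each article to `context_len` tokens via a randomly-positioned slice. A minimal sketch of those two guards (the function below and its list-of-tokens representation are illustrative, not code from the benchmarked repository):

```python
import random

def build_objective_set(articles, number, context_len, trim=True, seed=0):
    """Cap the objective set at `number` articles and (optionally) trim each
    article to at most `context_len` tokens via a random starting offset."""
    rng = random.Random(seed)
    objective_set = []
    for article in articles:
        if len(objective_set) >= number:
            break  # size cap: never exceed `number` articles
        if trim and len(article) > context_len:
            start = rng.randrange(len(article) - context_len)
            article = article[start : start + context_len]
        objective_set.append(article)
    return objective_set
```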
requirements.txt
CHANGED

@@ -1,18 +1,20 @@
 GitPython==3.1.43
 Pygments==2.18.0
 cohere==5.9.2
+configargparse
 fastapi==0.112.2
 gradio>=4.26.0
 langchain==0.2.16
 langchain-anthropic==0.1.23
 langchain-cohere==0.2.4
 langchain-community==0.2.17
-langchain-core==0.2.
+langchain-core==0.2.41
 langchain-experimental==0.0.65
 langchain-nvidia-ai-endpoints==0.2.2
 langchain-ollama==0.1.3
 langchain-openai==0.1.25
 langchain-text-splitters==0.2.4
+langchain-voyageai==0.1.1
 marqo==3.7.0
 nbformat==5.10.4
 openai==1.42.0
@@ -22,8 +24,10 @@ python-dotenv==1.0.1
 requests==2.32.3
 semchunk==2.2.0
 sentence-transformers==3.1.0
+tenacity==8.5.0
 tiktoken==0.7.0
 tokenizers==0.19.1
 transformers==4.44.2
 tree-sitter==0.22.3
 tree-sitter-language-pack==0.2.0
+voyageai==0.2.3
sage/chat.py
CHANGED

@@ -7,18 +7,15 @@ import logging
 
 import configargparse
 import gradio as gr
-import pkg_resources
 from dotenv import load_dotenv
 from langchain.chains import create_history_aware_retriever, create_retrieval_chain
 from langchain.chains.combine_documents import create_stuff_documents_chain
-from langchain.retrievers import ContextualCompressionRetriever
 from langchain.schema import AIMessage, HumanMessage
 from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
 
 import sage.config as sage_config
 from sage.llm import build_llm_via_langchain
-from sage.
-from sage.vector_store import build_vector_store_from_args
+from sage.retriever import build_retriever_from_args
 
 load_dotenv()
 
@@ -26,12 +23,7 @@ load_dotenv()
 
 def build_rag_chain(args):
     """Builds a RAG chain via LangChain."""
     llm = build_llm_via_langchain(args.llm_provider, args.llm_model)
-
-    retriever_top_k = 5 if args.reranker_provider == "none" else 25
-    retriever = build_vector_store_from_args(args).as_retriever(top_k=retriever_top_k)
-    compressor = build_reranker(args.reranker_provider, args.reranker_model)
-    if compressor:
-        retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)
+    retriever = build_retriever_from_args(args)
 
     # Prompt to contextualize the latest query based on the chat history.
     contextualize_q_system_prompt = (
@@ -83,6 +75,7 @@ def main():
 
     arg_validators = [
        sage_config.add_repo_args(parser),
+        sage_config.add_embedding_args(parser),
        sage_config.add_vector_store_args(parser),
        sage_config.add_reranking_args(parser),
        sage_config.add_llm_args(parser),
sage/chunker.py
CHANGED

@@ -299,6 +299,7 @@ class UniversalFileChunker(Chunker):
     """Chunks a file into smaller pieces, regardless of whether it's code or text."""
 
     def __init__(self, max_tokens: int):
+        self.max_tokens = max_tokens
         self.code_chunker = CodeFileChunker(max_tokens)
         self.ipynb_chunker = IpynbFileChunker(self.code_chunker)
         self.text_chunker = TextFileChunker(max_tokens)
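The one-line change above stores `max_tokens` on the chunker so downstream consumers (for instance the Voyage embedder's token-throughput estimate, `chunk_count * self.chunker.max_tokens`) can read the budget. As an illustration of what a fixed token budget means for chunking, a toy splitter follows; sage's real chunkers are code- and notebook-aware, so this only sketches the budget idea:

```python
def chunk_by_budget(tokens, max_tokens):
    """Split a token sequence into consecutive chunks of at most max_tokens each."""
    return [tokens[i : i + max_tokens] for i in range(0, len(tokens), max_tokens)]
```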
sage/config.py
CHANGED

@@ -28,6 +28,25 @@ OPENAI_DEFAULT_EMBEDDING_SIZE = {
     "text-embedding-3-large": 3072,
 }
 
+VOYAGE_MAX_CHUNKS_PER_BATCH = 128
+
+def get_voyage_max_tokens_per_batch(model: str) -> int:
+    """Returns the maximum number of tokens per batch for the Voyage model.
+    See https://docs.voyageai.com/reference/embeddings-api."""
+    if model == "voyage-3-lite":
+        return 1_000_000
+    if model in ["voyage-3", "voyage-2"]:
+        return 320_000
+    return 120_000
+
+def get_voyage_embedding_size(model: str) -> int:
+    """Returns the embedding size for the Voyage model. See https://docs.voyageai.com/docs/embeddings#model-choices."""
+    if model == "voyage-3-lite":
+        return 512
+    if model == "voyage-2-code":
+        return 1536
+    return 1024
+
 
 def add_config_args(parser: ArgumentParser):
     """Adds configuration-related arguments to the parser."""
@@ -61,7 +80,7 @@ def add_repo_args(parser: ArgumentParser) -> Callable:
 
 def add_embedding_args(parser: ArgumentParser) -> Callable:
     """Adds embedding-related arguments to the parser and returns a validator."""
-    parser.add("--embedding-provider", default="marqo", choices=["openai", "marqo"])
+    parser.add("--embedding-provider", default="marqo", choices=["openai", "voyage", "marqo"])
     parser.add(
         "--embedding-model",
         type=str,
@@ -115,11 +134,17 @@ def add_vector_store_args(parser: ArgumentParser) -> Callable:
         help="URL for the Marqo server. Required if using Marqo as embedder or vector store.",
     )
     parser.add(
-        "--
-        help="
-        "retrieval.
+        "--retrieval-alpha",
+        default=0.5,
+        type=float,
+        help="Takes effect for Pinecone retriever only. The weight of the dense (embeddings-based) vs sparse (BM25) "
+        "encoder in the final retrieval score. A value of 0.0 means BM25 only, 1.0 means embeddings only.",
+    )
+    parser.add(
+        "--retriever-top-k",
+        default=25,
+        type=int,
+        help="The number of top documents to retrieve from the vector store."
     )
     return validate_vector_store_args
 
@@ -169,6 +194,7 @@ def add_reranking_args(parser: ArgumentParser) -> Callable:
         help="The reranker model name. When --reranker-provider=huggingface, we suggest choosing a model from the "
         "SentenceTransformers Cross-Encoders library https://huggingface.co/cross-encoder?sort_models=downloads#models",
     )
+    parser.add("--reranker-top-k", default=5, help="The number of top documents to return after reranking.")
     # Trivial validator (nothing to check).
     return lambda _: True
 
@@ -228,6 +254,33 @@ def _validate_openai_embedding_args(args):
         raise ValueError(f"The maximum number of chunks per job is {OPENAI_MAX_TOKENS_PER_JOB}. Got {chunks_per_job}")
 
 
+def _validate_voyage_embedding_args(args):
+    """Validates the configuration of the Voyage batch embedder and sets defaults."""
+    if args.embedding_provider == "voyage" and not os.getenv("VOYAGE_API_KEY"):
+        raise ValueError("Please set the VOYAGE_API_KEY environment variable.")
+
+    if not args.embedding_model:
+        args.embedding_model = "voyage-code-2"
+
+    if not args.tokens_per_chunk:
+        # https://arxiv.org/pdf/2406.14497 recommends a value between 200-800.
+        args.tokens_per_chunk = 800
+
+    if not args.chunks_per_batch:
+        args.chunks_per_batch = VOYAGE_MAX_CHUNKS_PER_BATCH
+    elif args.chunks_per_batch > VOYAGE_MAX_CHUNKS_PER_BATCH:
+        args.chunks_per_batch = VOYAGE_MAX_CHUNKS_PER_BATCH
+        logging.warning(f"Voyage enforces a limit of {VOYAGE_MAX_CHUNKS_PER_BATCH} chunks per batch. Overwriting.")
+
+    max_tokens = get_voyage_max_tokens_per_batch(args.embedding_model)
+    if args.tokens_per_chunk * args.chunks_per_batch > max_tokens:
+        raise ValueError(f"Voyage enforces a limit of {max_tokens} tokens per batch. "
+                         "Reduce either --tokens-per-chunk or --chunks-per-batch.")
+
+    if not args.embedding_size:
+        args.embedding_size = get_voyage_embedding_size(args.embedding_model)
+
+
 def _validate_marqo_embedding_args(args):
     """Validates the configuration of the Marqo batch embedder and sets defaults."""
     if not args.embedding_model:
@@ -247,6 +300,8 @@ def validate_embedding_args(args):
     """Validates the configuration of the batch embedder and sets defaults."""
     if args.embedding_provider == "openai":
         _validate_openai_embedding_args(args)
+    elif args.embedding_provider == "voyage":
+        _validate_voyage_embedding_args(args)
     elif args.embedding_provider == "marqo":
         _validate_marqo_embedding_args(args)
     else:
@@ -257,8 +312,11 @@ def validate_vector_store_args(args):
     """Validates the configuration of the vector store and sets defaults."""
 
     if not args.index_namespace:
+        # Attempt to derive a default index namespace from the repository information.
+        if "repo_id" not in args:
+            raise ValueError("Please set a value for --index-namespace.")
         args.index_namespace = args.repo_id
-    if args.commit_hash:
+        if "commit_hash" in args and args.commit_hash:
            args.index_namespace += "/" + args.commit_hash
     if args.vector_store_provider == "marqo":
         # Marqo doesn't allow slashes in the index namespace.
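The `--retrieval-alpha` flag added above controls a convex blend of dense (embedding) and sparse (BM25) scores. A minimal sketch of the weighting follows; the actual combination happens inside Pinecone's hybrid search rather than in sage itself:

```python
def hybrid_score(dense_score: float, sparse_score: float, alpha: float) -> float:
    """Convex combination used by Pinecone-style hybrid search:
    alpha=1.0 -> dense (embeddings) only, alpha=0.0 -> sparse (BM25) only."""
    return alpha * dense_score + (1 - alpha) * sparse_score

# With the default --retrieval-alpha=0.5, dense and sparse contribute equally.
```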
sage/configs/remote.yaml
CHANGED

@@ -14,5 +14,4 @@ llm-provider: openai
 llm-model: gpt-4
 
 # Reranking
-
-reranking-model: rerank-english-v3.0
+reranker-provider: nvidia
sage/embedder.py
CHANGED

@@ -9,7 +9,9 @@ from collections import Counter
 from typing import Dict, Generator, List, Optional, Tuple
 
 import marqo
+import requests
 from openai import OpenAI
+from tenacity import retry, stop_after_attempt, wait_random_exponential
 
 from sage.chunker import Chunk, Chunker
 from sage.constants import TEXT_FIELD
@@ -205,6 +207,72 @@ class OpenAIBatchEmbedder(BatchEmbedder):
         }
 
 
+class VoyageBatchEmbedder(BatchEmbedder):
+    """Batch embedder that calls Voyage. See https://docs.voyageai.com/reference/embeddings-api."""
+
+    def __init__(self, data_manager: DataManager, chunker: Chunker, embedding_model: str):
+        self.data_manager = data_manager
+        self.chunker = chunker
+        self.embedding_model = embedding_model
+        self.embedding_data = []
+
+    def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None):
+        """Issues batch embedding jobs for the entire dataset."""
+        batch = []
+        chunk_count = 0
+
+        for content, metadata in self.data_manager.walk():
+            chunks = self.chunker.chunk(content, metadata)
+            chunk_count += len(chunks)
+            batch.extend(chunks)
+
+            token_count = chunk_count * self.chunker.max_tokens
+            if token_count % 900_000 == 0:
+                logging.info("Pausing for 60 seconds to avoid rate limiting...")
+                time.sleep(60)  # Voyage API rate limits to 1m tokens per minute; we'll pause every 900k tokens.
+
+            if len(batch) > chunks_per_batch:
+                for i in range(0, len(batch), chunks_per_batch):
+                    sub_batch = batch[i : i + chunks_per_batch]
+                    logging.info("Embedding %d chunks...", len(sub_batch))
+                    result = self._make_batch_request(sub_batch)
+                    for chunk, datum in zip(sub_batch, result["data"]):
+                        self.embedding_data.append((chunk.metadata, datum["embedding"]))
+                batch = []
+
+        # Finally, commit the last batch.
+        if batch:
+            logging.info("Embedding %d chunks...", len(batch))
+            result = self._make_batch_request(batch)
+            for chunk, datum in zip(batch, result["data"]):
+                self.embedding_data.append((chunk.metadata, datum["embedding"]))
+
+        logging.info(f"Successfully embedded {chunk_count} chunks.")
+
+    def embeddings_are_ready(self, *args, **kwargs) -> bool:
+        """Checks whether the batch embedding jobs are done."""
+        # The Voyage API is synchronous, so once embed_dataset() returns, the embeddings are ready.
+        return True
+
+    def download_embeddings(self, *args, **kwargs) -> Generator[Vector, None, None]:
+        """Yields (chunk_metadata, embedding) pairs for each chunk in the dataset."""
+        for chunk_metadata, embedding in self.embedding_data:
+            yield (chunk_metadata, embedding)
+
+    @retry(wait=wait_random_exponential(multiplier=1, max=60), stop=stop_after_attempt(6))
+    def _make_batch_request(self, chunks: List[Chunk]) -> Dict:
+        """Makes a batch request to the Voyage API with exponential backoff when we hit rate limits."""
+        url = "https://api.voyageai.com/v1/embeddings"
+        headers = {"Authorization": f"Bearer {os.environ['VOYAGE_API_KEY']}", "Content-Type": "application/json"}
+        payload = {"input": [chunk.content for chunk in chunks], "model": self.embedding_model}
+
+        response = requests.post(url, json=payload, headers=headers)
+        if not response.status_code == 200:
+            raise ValueError(f"Failed to make batch request. Response: {response.text}")
+
+        return response.json()
+
+
 class MarqoEmbedder(BatchEmbedder):
     """Embedder that uses the open-source Marqo vector search engine.
 
@@ -270,6 +338,8 @@
 def build_batch_embedder_from_flags(data_manager: DataManager, chunker: Chunker, args) -> BatchEmbedder:
     if args.embedding_provider == "openai":
         return OpenAIBatchEmbedder(data_manager, chunker, args.local_dir, args.embedding_model, args.embedding_size)
+    elif args.embedding_provider == "voyage":
+        return VoyageBatchEmbedder(data_manager, chunker, args.embedding_model)
     elif args.embedding_provider == "marqo":
         return MarqoEmbedder(
             data_manager, chunker, index_name=args.index_namespace, url=args.marqo_url, model=args.embedding_model
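The Voyage embedder above calls the API synchronously, so every `_make_batch_request` payload has to fit the per-batch token limit; `_validate_voyage_embedding_args` in sage/config.py enforces this with a worst-case product. A sketch of that arithmetic (the helper name is illustrative): with the defaults of 800 tokens per chunk and 128 chunks per batch, the worst case is 102,400 tokens, comfortably under the 120,000-token fallback limit.

```python
def voyage_batch_fits(tokens_per_chunk: int, chunks_per_batch: int, max_tokens_per_batch: int) -> bool:
    """The worst-case token count of one batch must fit the per-batch API limit."""
    return tokens_per_chunk * chunks_per_batch <= max_tokens_per_batch

# Defaults: 800 * 128 = 102,400 tokens, which fits a 120,000-token limit.
```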
sage/index.py
CHANGED

@@ -90,7 +90,7 @@ def main():
         time.sleep(30)
 
     logging.info("Moving embeddings to the repo vector store...")
-    repo_vector_store = build_vector_store_from_args(args)
+    repo_vector_store = build_vector_store_from_args(args, repo_manager)
     repo_vector_store.ensure_exists()
     repo_vector_store.upsert(repo_embedder.download_embeddings(repo_jobs_file))
 
@@ -101,7 +101,7 @@ def main():
         time.sleep(30)
 
     logging.info("Moving embeddings to the issues vector store...")
-    issues_vector_store = build_vector_store_from_args(args)
+    issues_vector_store = build_vector_store_from_args(args, issues_manager)
     issues_vector_store.ensure_exists()
     issues_vector_store.upsert(issues_embedder.download_embeddings(issues_jobs_file))
sage/reranker.py
CHANGED

@@ -8,6 +8,7 @@ from langchain_community.cross_encoders import HuggingFaceCrossEncoder
 from langchain_community.document_compressors import JinaRerank
 from langchain_core.documents import BaseDocumentCompressor
 from langchain_nvidia_ai_endpoints import NVIDIARerank
+from langchain_voyageai import VoyageAIRerank
 
 
 class RerankerProvider(Enum):
@@ -16,27 +17,33 @@ class RerankerProvider(Enum):
     COHERE = "cohere"
     NVIDIA = "nvidia"
     JINA = "jina"
+    VOYAGE = "voyage"
 
 
-def build_reranker(provider: str, model: Optional[str] = None,
+def build_reranker(provider: str, model: Optional[str] = None, top_k: Optional[int] = 5) -> BaseDocumentCompressor:
     if provider == RerankerProvider.NONE.value:
         return None
     if provider == RerankerProvider.HUGGINGFACE.value:
         model = model or "cross-encoder/ms-marco-MiniLM-L-6-v2"
         encoder_model = HuggingFaceCrossEncoder(model_name=model)
-        return CrossEncoderReranker(model=encoder_model, top_n=
+        return CrossEncoderReranker(model=encoder_model, top_n=top_k)
     if provider == RerankerProvider.COHERE.value:
         if not os.environ.get("COHERE_API_KEY"):
             raise ValueError("Please set the COHERE_API_KEY environment variable")
         model = model or "rerank-english-v3.0"
-        return CohereRerank(model=model, cohere_api_key=os.environ.get("COHERE_API_KEY"), top_n=
+        return CohereRerank(model=model, cohere_api_key=os.environ.get("COHERE_API_KEY"), top_n=top_k)
     if provider == RerankerProvider.NVIDIA.value:
         if not os.environ.get("NVIDIA_API_KEY"):
             raise ValueError("Please set the NVIDIA_API_KEY environment variable")
         model = model or "nvidia/nv-rerankqa-mistral-4b-v3"
-        return NVIDIARerank(model=model, api_key=os.environ.get("NVIDIA_API_KEY"), top_n=
+        return NVIDIARerank(model=model, api_key=os.environ.get("NVIDIA_API_KEY"), top_n=top_k, truncate="END")
     if provider == RerankerProvider.JINA.value:
         if not os.environ.get("JINA_API_KEY"):
             raise ValueError("Please set the JINA_API_KEY environment variable")
-        return JinaRerank(top_n=
+        return JinaRerank(top_n=top_k)
+    if provider == RerankerProvider.VOYAGE.value:
+        if not os.environ.get("VOYAGE_API_KEY"):
+            raise ValueError("Please set the VOYAGE_API_KEY environment variable")
+        model = model or "rerank-1"
+        return VoyageAIRerank(model=model, api_key=os.environ.get("VOYAGE_API_KEY"), top_k=top_k)
     raise ValueError(f"Invalid reranker provider: {provider}")
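All of the providers above implement the same compressor contract: score each retrieved candidate against the query and keep only the `top_n` highest-scoring documents. A toy version of that contract with precomputed scores (real rerankers compute the scores with a cross-encoder model or an API call):

```python
def rerank_top_n(docs, scores, top_n):
    """Two-stage retrieval shape: the retriever over-fetches (e.g. --retriever-top-k=25)
    and the reranker keeps the top_n best candidates (e.g. --reranker-top-k=5)."""
    ranked = sorted(zip(docs, scores), key=lambda pair: pair[1], reverse=True)
    return [doc for doc, _ in ranked[:top_n]]
```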
sage/retriever.py
ADDED

@@ -0,0 +1,25 @@
+from langchain.retrievers import ContextualCompressionRetriever
+from langchain_openai import OpenAIEmbeddings
+from langchain_voyageai import VoyageAIEmbeddings
+
+
+from sage.reranker import build_reranker
+from sage.vector_store import build_vector_store_from_args
+
+
+def build_retriever_from_args(args):
+    """Builds a retriever (with optional reranking) from command-line arguments."""
+
+    if args.embedding_provider == "openai":
+        embeddings = OpenAIEmbeddings(model=args.embedding_model)
+    elif args.embedding_provider == "voyage":
+        embeddings = VoyageAIEmbeddings(model=args.embedding_model)
+    else:
+        embeddings = None
+
+    retriever = build_vector_store_from_args(args).as_retriever(top_k=args.retriever_top_k, embeddings=embeddings)
+
+    reranker = build_reranker(args.reranker_provider, args.reranker_model, args.reranker_top_k)
+    if reranker:
+        retriever = ContextualCompressionRetriever(base_compressor=reranker, base_retriever=retriever)
+    return retriever
sage/vector_store.py
CHANGED
|
@@ -1,19 +1,22 @@
|
|
| 1 |
"""Vector store abstraction and implementations."""
|
| 2 |
|
|
|
|
|
|
|
| 3 |
from abc import ABC, abstractmethod
|
| 4 |
from functools import cached_property
|
| 5 |
-
from typing import Dict, Generator, List, Tuple
|
| 6 |
|
| 7 |
import marqo
|
| 8 |
from langchain_community.retrievers import PineconeHybridSearchRetriever
|
| 9 |
from langchain_community.vectorstores import Marqo
|
| 10 |
from langchain_community.vectorstores import Pinecone as LangChainPinecone
|
| 11 |
from langchain_core.documents import Document
|
| 12 |
-
from
|
| 13 |
from pinecone import Pinecone, ServerlessSpec
|
| 14 |
from pinecone_text.sparse import BM25Encoder
|
| 15 |
|
| 16 |
from sage.constants import TEXT_FIELD
|
|
|
|
| 17 |
|
| 18 |
Vector = Tuple[Dict, List[float]] # (metadata, embedding)
|
| 19 |
|
|
@@ -41,24 +44,40 @@ class VectorStore(ABC):
|
|
| 41 |
self.upsert_batch(batch)
|
| 42 |
|
| 43 |
@abstractmethod
|
| 44 |
-
def as_retriever(self, top_k: int):
|
| 45 |
"""Converts the vector store to a LangChain retriever object."""
|
| 46 |
|
| 47 |
|
| 48 |
class PineconeVectorStore(VectorStore):
|
| 49 |
"""Vector store implementation using Pinecone."""
|
| 50 |
|
| 51 |
-
def __init__(self, index_name: str, namespace: str, dimension: int,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
self.index_name = index_name
|
| 53 |
self.dimension = dimension
|
| 54 |
self.client = Pinecone()
|
| 55 |
self.namespace = namespace
|
| 56 |
-
self.
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
@cached_property
|
| 64 |
def index(self):
```diff
@@ -84,7 +103,7 @@ class PineconeVectorStore(VectorStore):
             name=self.index_name,
             dimension=self.dimension,
             # See https://www.pinecone.io/learn/hybrid-search-intro/
-            metric="dotproduct" if self.
+            metric="dotproduct" if self.bm25_encoder else "cosine",
             spec=ServerlessSpec(cloud="aws", region="us-east-1"),
         )
@@ -98,19 +117,19 @@ class PineconeVectorStore(VectorStore):
 
         self.index.upsert(vectors=pinecone_vectors, namespace=self.namespace)
 
-    def as_retriever(self, top_k: int):
+    def as_retriever(self, top_k: int, embeddings: Embeddings):
         if self.bm25_encoder:
             return PineconeHybridSearchRetriever(
-                embeddings=
+                embeddings=embeddings,
                 sparse_encoder=self.bm25_encoder,
                 index=self.index,
                 namespace=self.namespace,
                 top_k=top_k,
-                alpha=
+                alpha=self.alpha,
             )
 
         return LangChainPinecone.from_existing_index(
-            index_name=self.index_name, embedding=
+            index_name=self.index_name, embedding=embeddings, namespace=self.namespace
         ).as_retriever(search_kwargs={"k": top_k})
 
@@ -128,7 +147,8 @@ class MarqoVectorStore(VectorStore):
         # Since Marqo is both an embedder and a vector store, the embedder is already doing the upsert.
         pass
 
-    def as_retriever(self, top_k: int):
+    def as_retriever(self, top_k: int, embeddings: Embeddings = None):
+        del embeddings  # Unused; The Marqo vector store is also an embedder.
         vectorstore = Marqo(client=self.client, index_name=self.index_name)
 
         # Monkey-patch the _construct_documents_from_results_without_score method to not expect a "metadata" field in
@@ -146,14 +166,32 @@ class MarqoVectorStore(VectorStore):
         return vectorstore.as_retriever(search_kwargs={"k": top_k})
 
 
-def build_vector_store_from_args(args: dict) -> VectorStore:
-    """Builds a vector store from the given command-line arguments.
+def build_vector_store_from_args(args: dict, data_manager: Optional[DataManager] = None) -> VectorStore:
+    """Builds a vector store from the given command-line arguments.
+
+    When `data_manager` is specified and hybrid retrieval is requested, we'll use it to fit a BM25 encoder on the corpus
+    of documents.
+    """
     if args.vector_store_provider == "pinecone":
+        bm25_cache = os.path.join(".bm25_cache", args.index_namespace, "bm25_encoder.json")
+
+        if not os.path.exists(bm25_cache) and data_manager:
+            logging.info("Fitting BM25 encoder on the corpus...")
+            corpus = [content for content, _ in data_manager.walk()]
+            bm25_encoder = BM25Encoder()
+            bm25_encoder.fit(corpus)
+            # Make sure the folder exists, before we dump the encoder.
+            bm25_folder = os.path.dirname(bm25_cache)
+            if not os.path.exists(bm25_folder):
+                os.makedirs(bm25_folder)
+            bm25_encoder.dump(bm25_cache)
+
         return PineconeVectorStore(
             index_name=args.pinecone_index_name,
             namespace=args.index_namespace,
             dimension=args.embedding_size if "embedding_size" in args else None,
-
+            alpha=args.retrieval_alpha,
+            bm25_cache=bm25_cache,
         )
     elif args.vector_store_provider == "marqo":
        return MarqoVectorStore(url=args.marqo_url, index_name=args.index_namespace)
```
sage/vector_store.py after this PR (context shown, added lines marked `+`):

```diff
 """Vector store abstraction and implementations."""
 
+import os
+import logging
 from abc import ABC, abstractmethod
 from functools import cached_property
+from typing import Dict, Generator, List, Optional, Tuple
 
 import marqo
 from langchain_community.retrievers import PineconeHybridSearchRetriever
 from langchain_community.vectorstores import Marqo
 from langchain_community.vectorstores import Pinecone as LangChainPinecone
 from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
 from pinecone import Pinecone, ServerlessSpec
 from pinecone_text.sparse import BM25Encoder
 
 from sage.constants import TEXT_FIELD
+from sage.data_manager import DataManager
 
 Vector = Tuple[Dict, List[float]]  # (metadata, embedding)
...
         self.upsert_batch(batch)
 
     @abstractmethod
+    def as_retriever(self, top_k: int, embeddings: Embeddings):
         """Converts the vector store to a LangChain retriever object."""
 
 
 class PineconeVectorStore(VectorStore):
     """Vector store implementation using Pinecone."""
 
+    def __init__(self, index_name: str, namespace: str, dimension: int, alpha: float, bm25_cache: Optional[str] = None):
+        """
+        Args:
+            index_name: The name of the Pinecone index to use. If it doesn't exist already, we'll create it.
+            namespace: The namespace within the index to use.
+            dimension: The dimension of the vectors.
+            alpha: The alpha parameter for hybrid search: alpha == 1.0 means pure dense search, alpha == 0.0 means pure
+                BM25, and 0.0 < alpha < 1.0 means a hybrid of the two.
+            bm25_cache: The path to the BM25 encoder file. If not specified, we'll use the default BM25 (fitted on the
+                MS MARCO dataset).
+        """
         self.index_name = index_name
         self.dimension = dimension
         self.client = Pinecone()
         self.namespace = namespace
+        self.alpha = alpha
+
+        if alpha < 1.0:
+            if bm25_cache and os.path.exists(bm25_cache):
+                logging.info("Loading BM25 encoder from cache.")
+                self.bm25_encoder = BM25Encoder()
+                self.bm25_encoder.load(path=bm25_cache)
+            else:
+                logging.info("Using default BM25 encoder (fitted to MS MARCO).")
+                self.bm25_encoder = BM25Encoder.default()
+        else:
+            self.bm25_encoder = None
 
     @cached_property
     def index(self):
```
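The constructor's branching on `alpha` and the BM25 cache can be summarized as a small decision function. This is a hypothetical helper for illustration only; `select_bm25_encoder` and its return labels are not part of sage:

```python
from typing import Optional

def select_bm25_encoder(alpha: float, bm25_cache: Optional[str], cache_exists: bool) -> Optional[str]:
    """Mirror of PineconeVectorStore.__init__'s branching (hypothetical helper).

    Returns a label describing which sparse encoder the store ends up with.
    """
    if alpha < 1.0:  # some sparse weight requested, so a BM25 encoder is needed
        if bm25_cache and cache_exists:
            return "bm25-from-cache"  # corpus-fitted encoder loaded from disk
        return "bm25-default-ms-marco"  # fallback: encoder fitted on MS MARCO
    return None  # alpha == 1.0: pure dense search, no sparse encoder
```

Note that the MS MARCO fallback fires even when a cache path is configured but the file does not exist yet, which is why `build_vector_store_from_args` fits and dumps the encoder before constructing the store.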
```diff
             name=self.index_name,
             dimension=self.dimension,
             # See https://www.pinecone.io/learn/hybrid-search-intro/
+            metric="dotproduct" if self.bm25_encoder else "cosine",
             spec=ServerlessSpec(cloud="aws", region="us-east-1"),
         )
...
         self.index.upsert(vectors=pinecone_vectors, namespace=self.namespace)
 
+    def as_retriever(self, top_k: int, embeddings: Embeddings):
         if self.bm25_encoder:
             return PineconeHybridSearchRetriever(
+                embeddings=embeddings,
                 sparse_encoder=self.bm25_encoder,
                 index=self.index,
                 namespace=self.namespace,
                 top_k=top_k,
+                alpha=self.alpha,
             )
 
         return LangChainPinecone.from_existing_index(
+            index_name=self.index_name, embedding=embeddings, namespace=self.namespace
         ).as_retriever(search_kwargs={"k": top_k})
```
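The `dotproduct` metric is what makes the `alpha` weighting work: in Pinecone-style hybrid search, the query's dense and sparse vectors are typically scaled by `alpha` and `1 - alpha` before a single dot-product lookup, so the final score decomposes as `alpha * dense_score + (1 - alpha) * sparse_score` (cosine would re-normalize the vectors and cancel the weights). A minimal sketch of that weighting, with hypothetical names rather than sage code:

```python
from typing import Dict, List, Tuple

def hybrid_scale(dense: List[float], sparse: Dict[str, list], alpha: float) -> Tuple[List[float], Dict[str, list]]:
    """Weight dense and sparse query vectors so that a dot-product index scores
    alpha * dense_score + (1 - alpha) * sparse_score. Hypothetical helper."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be in [0, 1]")
    hdense = [alpha * v for v in dense]
    hsparse = {"indices": sparse["indices"], "values": [(1 - alpha) * v for v in sparse["values"]]}
    return hdense, hsparse

# alpha == 1.0 keeps only the dense signal; alpha == 0.0 keeps only BM25.
```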
```diff
         # Since Marqo is both an embedder and a vector store, the embedder is already doing the upsert.
         pass
 
+    def as_retriever(self, top_k: int, embeddings: Embeddings = None):
+        del embeddings  # Unused; The Marqo vector store is also an embedder.
         vectorstore = Marqo(client=self.client, index_name=self.index_name)
 
         # Monkey-patch the _construct_documents_from_results_without_score method to not expect a "metadata" field in
...
         return vectorstore.as_retriever(search_kwargs={"k": top_k})
 
 
+def build_vector_store_from_args(args: dict, data_manager: Optional[DataManager] = None) -> VectorStore:
+    """Builds a vector store from the given command-line arguments.
+
+    When `data_manager` is specified and hybrid retrieval is requested, we'll use it to fit a BM25 encoder on the corpus
+    of documents.
+    """
     if args.vector_store_provider == "pinecone":
+        bm25_cache = os.path.join(".bm25_cache", args.index_namespace, "bm25_encoder.json")
+
+        if not os.path.exists(bm25_cache) and data_manager:
+            logging.info("Fitting BM25 encoder on the corpus...")
+            corpus = [content for content, _ in data_manager.walk()]
+            bm25_encoder = BM25Encoder()
+            bm25_encoder.fit(corpus)
+            # Make sure the folder exists, before we dump the encoder.
+            bm25_folder = os.path.dirname(bm25_cache)
+            if not os.path.exists(bm25_folder):
+                os.makedirs(bm25_folder)
+            bm25_encoder.dump(bm25_cache)
+
         return PineconeVectorStore(
             index_name=args.pinecone_index_name,
             namespace=args.index_namespace,
             dimension=args.embedding_size if "embedding_size" in args else None,
+            alpha=args.retrieval_alpha,
+            bm25_cache=bm25_cache,
         )
     elif args.vector_store_provider == "marqo":
         return MarqoVectorStore(url=args.marqo_url, index_name=args.index_namespace)
```
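The fit-or-load flow around the BM25 cache can be sketched independently of `pinecone_text`. In this sketch, `load_or_fit` and `fit_fn` are hypothetical stand-ins for the `BM25Encoder.fit`/`dump`/`load` calls above, using plain JSON so the pattern is visible on its own:

```python
import json
import os

def load_or_fit(corpus, cache_path, fit_fn):
    """Fit a JSON-serializable model on `corpus` only when no cache exists;
    otherwise reuse the cached parameters. Hypothetical helper, not part of sage."""
    if os.path.exists(cache_path):
        with open(cache_path) as f:
            return json.load(f)
    params = fit_fn(corpus)
    # Make sure the folder exists before dumping, as in build_vector_store_from_args.
    os.makedirs(os.path.dirname(cache_path), exist_ok=True)
    with open(cache_path, "w") as f:
        json.dump(params, f)
    return params
```

This is why indexing pays the fitting cost once per namespace: the first call walks the corpus and dumps the parameters, and every later call (e.g. from `sage/chat.py`, which has no `DataManager`) hits the cache.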