juliaturc committed on
Commit
27dd60a
·
1 Parent(s): ceffe55

Add an LLM retriever (#70)

Browse files
README.md CHANGED
@@ -225,10 +225,16 @@ Currently, we support the following types of retrieval:
225
  - Note this is not available when running locally, only when using Pinecone as a vector store.
226
  - Contrary to [Anthropic's findings](https://www.anthropic.com/news/contextual-retrieval), we find that BM25 is actually damaging performance *on codebases*, because it gives undeserved advantage to Markdown files.
227
 
228
- - **Multi-query retrieval** performs multiple query rewrites, makes a separate retrieval call for each, and takes the union of the retrieved documents. You can activate it by passing `--multi-query-retrieval`.
229
 
230
  - We find that [on our benchmark](benchmark/retrieval/README.md) this only marginally improves retrieval quality (from 0.44 to 0.46 R-precision) while being significantly slower and more expensive due to LLM calls. But your mileage may vary.
231
 
 
 
 
 
 
 
232
  </details>
233
 
234
  # Why chat with a codebase?
 
225
  - Note this is not available when running locally, only when using Pinecone as a vector store.
226
  - Contrary to [Anthropic's findings](https://www.anthropic.com/news/contextual-retrieval), we find that BM25 is actually damaging performance *on codebases*, because it gives undeserved advantage to Markdown files.
227
 
228
+ - **Multi-query retrieval** performs multiple query rewrites, makes a separate retrieval call for each, and takes the union of the retrieved documents. You can activate it by passing `--multi-query-retrieval`. This can be combined with both vanilla and hybrid RAG.
229
 
230
  - We find that [on our benchmark](benchmark/retrieval/README.md) this only marginally improves retrieval quality (from 0.44 to 0.46 R-precision) while being significantly slower and more expensive due to LLM calls. But your mileage may vary.
231
 
232
+ - **LLM-only retrieval** completely circumvents indexing the codebase. We simply enumerate all file paths and pass them to an LLM together with the user query. We ask the LLM which files are likely to be relevant for the user query, solely based on their filenames. You can activate it by passing `--llm-retriever`.
233
+
234
+ - We find that [on our benchmark](benchmark/retrieval/README.md) the performance is comparable with vector database solutions (R-precision is 0.44 for both). This is quite remarkable, since we've saved so much effort by not indexing the codebase. However, we are reluctant to claim that these findings generalize, for the following reasons:
235
+ - Our (artificial) dataset occasionally contains explicit path names in the query, making it trivial for the LLM. Sample query: *"Alice is managing a series of machine learning experiments. Please explain in detail how `main` in `examples/pytorch/image-pretraining/run_mim.py` allows her to organize the outputs of each experiment in separate directories."*
236
+ - Our benchmark focuses on the Transformers library, which is well-maintained and whose file paths are often meaningful. This might not be the case for all codebases.
237
+
238
  </details>
239
 
240
  # Why chat with a codebase?
benchmarks/retrieval/retrieve.py CHANGED
@@ -41,24 +41,12 @@ def main():
41
  )
42
 
43
  parser.add("--max-instances", default=None, type=int, help="Maximum number of instances to process.")
44
- sage.config.add_config_args(parser)
45
- sage.config.add_llm_args(parser) # Needed for --multi-query-retriever, which rewrites the query with an LLM.
46
- sage.config.add_embedding_args(parser)
47
- sage.config.add_vector_store_args(parser)
48
- sage.config.add_reranking_args(parser)
49
- sage.config.add_repo_args(parser)
50
- sage.config.add_indexing_args(parser)
51
  args = parser.parse_args()
52
- sage.config.validate_vector_store_args(args)
53
- repo_manager = GitHubRepoManager(
54
- args.repo_id,
55
- commit_hash=args.commit_hash,
56
- access_token=os.getenv("GITHUB_TOKEN"),
57
- local_dir=args.local_dir,
58
- inclusion_file=args.include,
59
- exclusion_file=args.exclude,
60
- )
61
- repo_manager.download()
62
  retriever = build_retriever_from_args(args, repo_manager)
63
 
64
  with open(args.benchmark, "r") as f:
 
41
  )
42
 
43
  parser.add("--max-instances", default=None, type=int, help="Maximum number of instances to process.")
44
+
45
+ validator = sage.config.add_all_args(parser)
 
 
 
 
 
46
  args = parser.parse_args()
47
+ validator(args)
48
+
49
+ repo_manager = GitHubRepoManager.from_args(args)
 
 
 
 
 
 
 
50
  retriever = build_retriever_from_args(args, repo_manager)
51
 
52
  with open(args.benchmark, "r") as f:
pyproject.toml CHANGED
@@ -23,6 +23,7 @@ classifiers = [
23
  dependencies = [
24
  "GitPython==3.1.43",
25
  "Pygments==2.18.0",
 
26
  "cohere==5.9.2",
27
  "configargparse",
28
  "fastapi==0.112.2",
@@ -46,6 +47,7 @@ dependencies = [
46
  "pinecone==5.0.1",
47
  "pinecone-text==0.9.0",
48
  "python-dotenv==1.0.1",
 
49
  "requests==2.32.3",
50
  "semchunk==2.2.0",
51
  "sentence-transformers==3.1.0",
 
23
  dependencies = [
24
  "GitPython==3.1.43",
25
  "Pygments==2.18.0",
26
+ "anytree==2.12.1",
27
  "cohere==5.9.2",
28
  "configargparse",
29
  "fastapi==0.112.2",
 
47
  "pinecone==5.0.1",
48
  "pinecone-text==0.9.0",
49
  "python-dotenv==1.0.1",
50
+ "python-Levenshtein==0.26.0",
51
  "requests==2.32.3",
52
  "semchunk==2.2.0",
53
  "sentence-transformers==3.1.0",
sage/chat.py CHANGED
@@ -71,20 +71,10 @@ def main():
71
  default=False,
72
  help="Whether to make the gradio app publicly accessible.",
73
  )
74
- sage_config.add_config_args(parser)
75
-
76
- arg_validators = [
77
- sage_config.add_repo_args(parser),
78
- sage_config.add_embedding_args(parser),
79
- sage_config.add_vector_store_args(parser),
80
- sage_config.add_reranking_args(parser),
81
- sage_config.add_llm_args(parser),
82
- ]
83
 
 
84
  args = parser.parse_args()
85
-
86
- for validator in arg_validators:
87
- validator(args)
88
 
89
  rag_chain = build_rag_chain(args)
90
 
 
71
  default=False,
72
  help="Whether to make the gradio app publicly accessible.",
73
  )
 
 
 
 
 
 
 
 
 
74
 
75
+ validator = sage_config.add_all_args(parser)
76
  args = parser.parse_args()
77
+ validator(args)
 
 
78
 
79
  rag_chain = build_rag_chain(args)
80
 
sage/config.py CHANGED
@@ -71,6 +71,7 @@ def add_config_args(parser: ArgumentParser):
71
  args, _ = parser.parse_known_args()
72
  config_file = pkg_resources.resource_filename(__name__, f"configs/{args.mode}.yaml")
73
  parser.set_defaults(config=config_file)
 
74
 
75
 
76
  def add_repo_args(parser: ArgumentParser) -> Callable:
@@ -157,6 +158,14 @@ def add_vector_store_args(parser: ArgumentParser) -> Callable:
157
  help="When set to True, we rewrite the query 5 times, perform retrieval for each rewrite, and take the union "
158
  "of retrieved documents. See https://python.langchain.com/v0.1/docs/modules/data_connection/retrievers/MultiQueryRetriever/.",
159
  )
 
 
 
 
 
 
 
 
160
  return validate_vector_store_args
161
 
162
 
@@ -221,6 +230,25 @@ def add_llm_args(parser: ArgumentParser) -> Callable:
221
  return lambda _: True
222
 
223
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  def validate_repo_args(args):
225
  """Validates the configuration of the repository."""
226
  if not re.match(r"^[^/]+/[^/]+$", args.repo_id):
@@ -233,7 +261,7 @@ def _validate_openai_embedding_args(args):
233
  raise ValueError("Please set the OPENAI_API_KEY environment variable.")
234
 
235
  if not args.embedding_model:
236
- args.embedding_model = "text-embedding-ada-002"
237
 
238
  if args.embedding_model not in OPENAI_DEFAULT_EMBEDDING_SIZE.keys():
239
  raise ValueError(f"Unrecognized embeddings.model={args.embedding_model}")
@@ -347,6 +375,19 @@ def validate_embedding_args(args):
347
 
348
  def validate_vector_store_args(args):
349
  """Validates the configuration of the vector store and sets defaults."""
 
 
 
 
 
 
 
 
 
 
 
 
 
350
 
351
  if not args.index_namespace:
352
  # Attempt to derive a default index namespace from the repository information.
 
71
  args, _ = parser.parse_known_args()
72
  config_file = pkg_resources.resource_filename(__name__, f"configs/{args.mode}.yaml")
73
  parser.set_defaults(config=config_file)
74
+ return lambda _: True
75
 
76
 
77
  def add_repo_args(parser: ArgumentParser) -> Callable:
 
158
  help="When set to True, we rewrite the query 5 times, perform retrieval for each rewrite, and take the union "
159
  "of retrieved documents. See https://python.langchain.com/v0.1/docs/modules/data_connection/retrievers/MultiQueryRetriever/.",
160
  )
161
+ parser.add(
162
+ "--llm-retriever",
163
+ action=argparse.BooleanOptionalAction,
164
+ default=False,
165
+ help="When set to True, we use an LLM for retrieval: we pass the repository file hierarchy together with the "
166
+ "user query and ask the LLM to choose relevant files solely based on their paths. No indexing will be done, so "
167
+ "all the vector store / embedding arguments will be ignored.",
168
+ )
169
  return validate_vector_store_args
170
 
171
 
 
230
  return lambda _: True
231
 
232
 
233
def add_all_args(parser: ArgumentParser) -> Callable:
    """Registers every argument group on `parser` and returns one combined validator.

    Args:
        parser: The argument parser to populate.

    Returns:
        A callable that takes the parsed args and runs each group's validator in registration order.
    """
    # Keep this order: config args first, LLM args last, exactly as the groups are meant to be registered.
    registrars = (
        add_config_args,
        add_repo_args,
        add_embedding_args,
        add_vector_store_args,
        add_reranking_args,
        add_indexing_args,
        add_llm_args,
    )
    # Each registrar adds its flags to the parser and hands back the validator for that group.
    group_validators = [register(parser) for register in registrars]

    def validate_all(args):
        for check in group_validators:
            check(args)

    return validate_all
250
+
251
+
252
  def validate_repo_args(args):
253
  """Validates the configuration of the repository."""
254
  if not re.match(r"^[^/]+/[^/]+$", args.repo_id):
 
261
  raise ValueError("Please set the OPENAI_API_KEY environment variable.")
262
 
263
  if not args.embedding_model:
264
+ args.embedding_model = "text-embedding-3-small"
265
 
266
  if args.embedding_model not in OPENAI_DEFAULT_EMBEDDING_SIZE.keys():
267
  raise ValueError(f"Unrecognized embeddings.model={args.embedding_model}")
 
375
 
376
  def validate_vector_store_args(args):
377
  """Validates the configuration of the vector store and sets defaults."""
378
+ if args.llm_retriever:
379
+ if not os.getenv("ANTHROPIC_API_KEY"):
380
+ raise ValueError(
381
+ "Please set the ANTHROPIC_API_KEY environment variable to use the LLM retriever. "
382
+ "(We're constrained to Claude because we need prompt caching.)"
383
+ )
384
+
385
+ if args.index_issues:
386
+ # The LLM retriever only makes sense on the code repository, since it passes file paths to the LLM.
387
+ raise ValueError("Cannot use --index-issues with --llm-retriever.")
388
+
389
+ # When using an LLM retriever, all the vector store arguments are ignored.
390
+ return
391
 
392
  if not args.index_namespace:
393
  # Attempt to derive a default index namespace from the repository information.
sage/data_manager.py CHANGED
@@ -175,15 +175,14 @@ class GitHubRepoManager(DataManager):
175
  )
176
  return True
177
 
178
- def walk(self) -> Generator[Tuple[Any, Dict], None, None]:
179
  """Walks the local repository path and yields a tuple of (content, metadata) for each file.
180
  The filepath is relative to the root of the repository (e.g. "org/repo/your/file/path.py").
181
 
182
  Args:
183
- included_extensions: Optional set of extensions to include.
184
- excluded_extensions: Optional set of extensions to exclude.
185
  """
186
- # We will keep apending to these files during the iteration, so we need to clear them first.
187
  repo_name = self.repo_id.replace("/", "_")
188
  included_log_file = os.path.join(self.log_dir, f"included_{repo_name}.txt")
189
  excluded_log_file = os.path.join(self.log_dir, f"excluded_{repo_name}.txt")
@@ -208,20 +207,49 @@ class GitHubRepoManager(DataManager):
208
  f.write(path + "\n")
209
 
210
  for file_path in included_file_paths:
 
 
 
 
 
 
 
 
 
 
211
  with open(file_path, "r") as f:
212
  try:
213
  contents = f.read()
214
  except UnicodeDecodeError:
215
  logging.warning("Unable to decode file %s. Skipping.", file_path)
216
  continue
217
- relative_file_path = file_path[len(self.local_dir) + 1 :]
218
- metadata = {
219
- "file_path": relative_file_path,
220
- "url": self.url_for_file(relative_file_path),
221
- }
222
  yield contents, metadata
223
 
224
  def url_for_file(self, file_path: str) -> str:
225
  """Converts a repository file path to a GitHub link."""
226
  file_path = file_path[len(self.repo_id) + 1 :]
227
  return f"https://github.com/{self.repo_id}/blob/{self.default_branch}/{file_path}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
  )
176
  return True
177
 
178
+ def walk(self, get_content: bool = True) -> Generator[Tuple[Any, Dict], None, None]:
179
  """Walks the local repository path and yields a tuple of (content, metadata) for each file.
180
  The filepath is relative to the root of the repository (e.g. "org/repo/your/file/path.py").
181
 
182
  Args:
183
+ get_content: When set to True, yields (content, metadata) tuples. When set to False, yields metadata only.
 
184
  """
185
+ # We will keep appending to these files during the iteration, so we need to clear them first.
186
  repo_name = self.repo_id.replace("/", "_")
187
  included_log_file = os.path.join(self.log_dir, f"included_{repo_name}.txt")
188
  excluded_log_file = os.path.join(self.log_dir, f"excluded_{repo_name}.txt")
 
207
  f.write(path + "\n")
208
 
209
  for file_path in included_file_paths:
210
+ relative_file_path = file_path[len(self.local_dir) + 1 :]
211
+ metadata = {
212
+ "file_path": relative_file_path,
213
+ "url": self.url_for_file(relative_file_path),
214
+ }
215
+
216
+ if not get_content:
217
+ yield metadata
218
+ continue
219
+
220
  with open(file_path, "r") as f:
221
  try:
222
  contents = f.read()
223
  except UnicodeDecodeError:
224
  logging.warning("Unable to decode file %s. Skipping.", file_path)
225
  continue
 
 
 
 
 
226
  yield contents, metadata
227
 
228
  def url_for_file(self, file_path: str) -> str:
229
  """Converts a repository file path to a GitHub link."""
230
  file_path = file_path[len(self.repo_id) + 1 :]
231
  return f"https://github.com/{self.repo_id}/blob/{self.default_branch}/{file_path}"
232
+
233
+ def read_file(self, relative_file_path: str) -> str:
234
+ """Reads the content of the file at the given path."""
235
+ file_path = os.path.join(self.local_dir, relative_file_path)
236
+ with open(file_path, "r") as f:
237
+ return f.read()
238
+
239
@classmethod
def from_args(cls, args):
    """Creates a repo manager from command-line arguments and clones the underlying repository.

    Declared as a classmethod so that it works both as `GitHubRepoManager.from_args(args)` and when accessed
    through an instance or a subclass.

    Args:
        args: Parsed command-line arguments (accessed by attribute). Must provide `repo_id`, `commit_hash`,
            `local_dir`, `include` and `exclude`.

    Returns:
        A manager whose repository has been downloaded.

    Raises:
        ValueError: If the repository could not be cloned.
    """
    repo_manager = cls(
        repo_id=args.repo_id,
        commit_hash=args.commit_hash,
        # The token is read from the environment rather than from the command line.
        access_token=os.getenv("GITHUB_TOKEN"),
        local_dir=args.local_dir,
        inclusion_file=args.include,
        exclusion_file=args.exclude,
    )
    success = repo_manager.download()
    if not success:
        raise ValueError(
            f"Unable to clone {args.repo_id}. Please check that it exists and you have access to it. "
            "For private repositories, please set the GITHUB_TOKEN variable in your environment."
        )
    return repo_manager
sage/index.py CHANGED
@@ -36,6 +36,10 @@ def main():
36
  for validator in arg_validators:
37
  validator(args)
38
 
 
 
 
 
39
  # Additionally validate embedder and vector store compatibility.
40
  if args.embedding_provider == "openai" and args.vector_store_provider != "pinecone":
41
  parser.error("When using OpenAI embedder, the vector store type must be Pinecone.")
@@ -50,22 +54,7 @@ def main():
50
  repo_embedder = None
51
  if args.index_repo:
52
  logging.info("Cloning the repository...")
53
- repo_manager = GitHubRepoManager(
54
- args.repo_id,
55
- commit_hash=args.commit_hash,
56
- access_token=os.getenv("GITHUB_TOKEN"),
57
- local_dir=args.local_dir,
58
- inclusion_file=args.include,
59
- exclusion_file=args.exclude,
60
- )
61
-
62
- success = repo_manager.download()
63
- if not success:
64
- raise ValueError(
65
- f"Unable to clone {args.repo_id}. Please check that it exists and you have access to it. "
66
- "For private repositories, please set the GITHUB_TOKEN variable in your environment."
67
- )
68
-
69
  logging.info("Embedding the repo...")
70
  chunker = UniversalFileChunker(max_tokens=args.tokens_per_chunk)
71
  repo_embedder = build_batch_embedder_from_flags(repo_manager, chunker, args)
 
36
  for validator in arg_validators:
37
  validator(args)
38
 
39
+ if args.llm_retriever:
40
+ logging.warning("The LLM retriever does not require indexing, so this script is a no-op.")
41
+ return
42
+
43
  # Additionally validate embedder and vector store compatibility.
44
  if args.embedding_provider == "openai" and args.vector_store_provider != "pinecone":
45
  parser.error("When using OpenAI embedder, the vector store type must be Pinecone.")
 
54
  repo_embedder = None
55
  if args.index_repo:
56
  logging.info("Cloning the repository...")
57
+ repo_manager = GitHubRepoManager.from_args(args)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  logging.info("Embedding the repo...")
59
  chunker = UniversalFileChunker(max_tokens=args.tokens_per_chunk)
60
  repo_embedder = build_batch_embedder_from_flags(repo_manager, chunker, args)
sage/retriever.py CHANGED
@@ -1,32 +1,221 @@
1
- from typing import Optional
 
 
2
 
 
 
 
 
3
  from langchain.retrievers import ContextualCompressionRetriever
4
  from langchain.retrievers.multi_query import MultiQueryRetriever
 
 
5
  from langchain_google_genai import GoogleGenerativeAIEmbeddings
6
  from langchain_openai import OpenAIEmbeddings
7
  from langchain_voyageai import VoyageAIEmbeddings
 
8
 
9
- from sage.data_manager import DataManager
10
  from sage.llm import build_llm_via_langchain
11
  from sage.reranker import build_reranker
12
  from sage.vector_store import build_vector_store_from_args
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  def build_retriever_from_args(args, data_manager: Optional[DataManager] = None):
16
  """Builds a retriever (with optional reranking) from command-line arguments."""
17
-
18
- if args.embedding_provider == "openai":
19
- embeddings = OpenAIEmbeddings(model=args.embedding_model)
20
- elif args.embedding_provider == "voyage":
21
- embeddings = VoyageAIEmbeddings(model=args.embedding_model)
22
- elif args.embedding_provider == "gemini":
23
- embeddings = GoogleGenerativeAIEmbeddings(model=args.embedding_model)
24
  else:
25
- embeddings = None
 
 
 
 
 
 
 
26
 
27
- retriever = build_vector_store_from_args(args, data_manager).as_retriever(
28
- top_k=args.retriever_top_k, embeddings=embeddings, namespace=args.index_namespace
29
- )
30
 
31
  if args.multi_query_retriever:
32
  retriever = MultiQueryRetriever.from_llm(
 
1
+ import logging
2
+ import os
3
+ from typing import List, Optional
4
 
5
+ import anthropic
6
+ import Levenshtein
7
+ from anytree import Node, RenderTree
8
+ from langchain.callbacks.manager import CallbackManagerForRetrieverRun
9
  from langchain.retrievers import ContextualCompressionRetriever
10
  from langchain.retrievers.multi_query import MultiQueryRetriever
11
+ from langchain.schema import BaseRetriever, Document
12
+ from langchain_core.output_parsers import CommaSeparatedListOutputParser
13
  from langchain_google_genai import GoogleGenerativeAIEmbeddings
14
  from langchain_openai import OpenAIEmbeddings
15
  from langchain_voyageai import VoyageAIEmbeddings
16
+ from pydantic import Field
17
 
18
+ from sage.data_manager import DataManager, GitHubRepoManager
19
  from sage.llm import build_llm_via_langchain
20
  from sage.reranker import build_reranker
21
  from sage.vector_store import build_vector_store_from_args
22
 
23
+ logging.basicConfig(level=logging.INFO)
24
+ logger = logging.getLogger()
25
+ logger.setLevel(logging.INFO)
26
+
27
+
28
+ class LLMRetriever(BaseRetriever):
29
+ """Custom Langchain retriever based on an LLM.
30
+
31
+ Builds a representation of the folder structure of the repo, feeds it to an LLM, and asks the LLM for the most
32
+ relevant files for a particular user query, expecting it to make decisions based solely on file names.
33
+
34
+ Only works with Claude/Anthropic, because it's very slow (e.g. 15s for a mid-sized codebase) and we need prompt
35
+ caching to make it usable.
36
+ """
37
+
38
+ repo_manager: GitHubRepoManager = Field(...)
39
+ top_k: int = Field(...)
40
+ all_repo_files: List[str] = Field(...)
41
+ repo_hierarchy: str = Field(...)
42
+
43
+ def __init__(self, repo_manager: GitHubRepoManager, top_k: int):
44
+ super().__init__()
45
+ self.repo_manager = repo_manager
46
+ self.top_k = top_k
47
+
48
+ # Best practice would be to make these fields @cached_property, but that impedes class serialization.
49
+ self.all_repo_files = [metadata["file_path"] for metadata in self.repo_manager.walk(get_content=False)]
50
+ self.repo_hierarchy = LLMRetriever._render_file_hierarchy(self.all_repo_files)
51
+
52
+ if not os.environ.get("ANTHROPIC_API_KEY"):
53
+ raise ValueError("Please set the ANTHROPIC_API_KEY environment variable for the LLMRetriever.")
54
+
55
+ def _get_relevant_documents(self, query: str, *, run_manager: CallbackManagerForRetrieverRun) -> List[Document]:
56
+ """Retrieve relevant documents for a given query."""
57
+ filenames = self._ask_llm_to_retrieve(user_query=query, top_k=self.top_k)
58
+ documents = []
59
+ for filename in filenames:
60
+ document = Document(
61
+ page_content=self.repo_manager.read_file(filename),
62
+ metadata={"file_path": filename, "url": self.repo_manager.url_for_file(filename)},
63
+ )
64
+ documents.append(document)
65
+ return documents
66
+
67
+ def _ask_llm_to_retrieve(self, user_query: str, top_k: int) -> List[str]:
68
+ """Feeds the file hierarchy and user query to the LLM and asks which files might be relevant."""
69
+ sys_prompt = f"""
70
+ You are a retriever system. You will be given a user query and a list of files in a GitHub repository. Your task is to determine the top {top_k} files that are most relevant to the user query.
71
+ DO NOT RESPOND TO THE USER QUERY DIRECTLY. Instead, respond with full paths to relevant files that could contain the answer to the query. Say absolutely nothing else other than the file paths.
72
+
73
+ Here is the file hierarchy of the GitHub repository:
74
+
75
+ {self.repo_hierarchy}
76
+ """
77
+ # We are deliberately repeating the "DO NOT RESPOND TO THE USER QUERY DIRECTLY" instruction here.
78
+ augmented_user_query = f"""
79
+ User query: {user_query}
80
+
81
+ DO NOT RESPOND TO THE USER QUERY DIRECTLY. Instead, respond with full paths to relevant files that could contain the answer to the query. Say absolutely nothing else other than the file paths.
82
+ """
83
+ response = LLMRetriever._call_via_anthropic_with_prompt_caching(sys_prompt, augmented_user_query)
84
+ files_from_llm = response.content[0].text.strip().split("\n")
85
+ validated_files = []
86
+
87
+ for filename in files_from_llm:
88
+ if filename not in self.all_repo_files:
89
+ if "/" not in filename:
90
+ # This is most likely some natural language excuse from the LLM; skip it.
91
+ continue
92
+ # Try a few heuristics to fix the filename.
93
+ filename = LLMRetriever._fix_filename(filename, self.repo_manager.repo_id)
94
+ if filename not in self.all_repo_files:
95
+ # The heuristics failed; try to find the closest filename in the repo.
96
+ filename = LLMRetriever._find_closest_filename(filename, self.all_repo_files)
97
+ if filename in self.all_repo_files:
98
+ validated_files.append(filename)
99
+ return validated_files
100
+
101
+ @staticmethod
102
+ def _call_via_anthropic_with_prompt_caching(system_prompt: str, user_prompt: str) -> str:
103
+ """Calls the Anthropic API with prompt caching for the system prompt.
104
+
105
+ See https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching.
106
+
107
+ We're circumventing LangChain for now, because the feature is < 1 week old at the time of writing and has no
108
+ documentation: https://github.com/langchain-ai/langchain/pull/27087
109
+ """
110
+ CLAUDE_MODEL = "claude-3-5-sonnet-20240620"
111
+ client = anthropic.Anthropic()
112
+
113
+ system_message = {"type": "text", "text": system_prompt, "cache_control": {"type": "ephemeral"}}
114
+ user_message = {"role": "user", "content": user_prompt}
115
+
116
+ response = client.beta.prompt_caching.messages.create(
117
+ model=CLAUDE_MODEL,
118
+ max_tokens=1024, # The maximum number of *output* tokens to generate.
119
+ system=[system_message],
120
+ messages=[user_message],
121
+ )
122
+ # Caching information will be under `cache_creation_input_tokens` and `cache_read_input_tokens`.
123
+ # Note that, for prompts shorter than 1024 tokens, Anthropic will not do any caching.
124
+ logging.info("Anthropic prompt caching info: %s", response.usage)
125
+ return response
126
+
127
@staticmethod
def _render_file_hierarchy(file_paths: List[str]) -> str:
    """Given a list of files, produces an indented visualization of the file hierarchy. For instance:

    folder1
        folder11
            file111.py
            file112.py
        folder12
            file121.py

    Only the tree rooted at the first path's top-level component is rendered.

    Args:
        file_paths: "/"-separated file paths.

    Returns:
        One line per folder/file, with the tree-drawing characters blanked out. Returns an empty string when
        `file_paths` is empty.
    """
    if not file_paths:
        # Nothing to render; avoids indexing into an empty list below.
        return ""

    # The "nodepath" is the path from root to the node (e.g. huggingface/transformers/examples)
    nodepath_to_node = {}

    for path in file_paths:
        nodepath = ""
        parent_node = None
        for item in path.split("/"):
            nodepath = f"{nodepath}/{item}"
            node = nodepath_to_node.get(nodepath)
            if node is None:
                node = Node(item, parent=parent_node)
                nodepath_to_node[nodepath] = node
            parent_node = node

    root_path = f"/{file_paths[0].split('/')[0]}"
    root_node = nodepath_to_node[root_path]

    # Replace special lines with spaces to save on tokens (single pass instead of chained replaces).
    blank_box_chars = str.maketrans({char: " " for char in "└├│─"})
    rendered_lines = [f"{pre}{node.name}\n".translate(blank_box_chars) for pre, _, node in RenderTree(root_node)]
    return "".join(rendered_lines)
164
+
165
@staticmethod
def _fix_filename(filename: str, repo_id: str) -> str:
    """Attempts to "fix" a filename output by the LLM.

    Common issues with LLM-generated filenames:
    - The LLM prepends an extraneous "/".
    - The LLM omits the name of the org (e.g. "transformers/README.md" instead of "huggingface/transformers/README.md").
    - The LLM omits the name of the repo (e.g. "huggingface/README.md" instead of "huggingface/transformers/README.md").
    - The LLM omits the org/repo prefix (e.g. "README.md" instead of "huggingface/transformers/README.md").

    Args:
        filename: The path produced by the LLM.
        repo_id: The repository identifier, in "org/repo" form.

    Returns:
        The path with an "org/repo/" prefix. The result is not guaranteed to exist in the repository.
    """
    if filename.startswith("/"):
        filename = filename[1:]
    org_name, repo_name = repo_id.split("/")
    items = filename.split("/")

    # Compare whole path components, not raw string prefixes: a folder such as "huggingface_hub" must not be
    # mistaken for the org "huggingface", nor "transformers_utils" for the repo "transformers".
    starts_with_org = items[0] == org_name
    if starts_with_org and len(items) > 1 and items[1] == repo_name:
        # Already has the full org/repo prefix.
        return filename
    if starts_with_org:
        # The repo name is missing; insert it right after the org.
        return "/".join([org_name, repo_name] + items[1:])
    if items[0] == repo_name:
        # The org name is missing.
        return f"{org_name}/{filename}"
    # Both org and repo are missing.
    return f"{org_name}/{repo_name}/{filename}"
187
+
188
@staticmethod
def _find_closest_filename(filename: str, repo_filenames: List[str], max_edit_distance: int = 10) -> Optional[str]:
    """Returns the path in the repo with smallest edit distance from `filename`.

    Helpful when the `filename` was generated by an LLM and parts of it might have been hallucinated.

    Args:
        filename: The (possibly hallucinated) path to match.
        repo_filenames: Candidate paths from the repository.
        max_edit_distance: Largest Levenshtein distance still accepted as a match.

    Returns:
        The closest path, or None if `repo_filenames` is empty or the closest path is more than
        `max_edit_distance` away. In case of a tie, the earliest closest path in `repo_filenames` is returned.
    """
    if not repo_filenames:
        # No candidates to compare against.
        return None

    # A single pass with min() is enough; there is no need to sort every candidate.
    closest_path, distance = min(
        ((path, Levenshtein.distance(filename, path)) for path in repo_filenames),
        key=lambda pair: pair[1],
    )
    if distance <= max_edit_distance:
        return closest_path
    return None
200
+
201
 
202
  def build_retriever_from_args(args, data_manager: Optional[DataManager] = None):
203
  """Builds a retriever (with optional reranking) from command-line arguments."""
204
+ if args.llm_retriever:
205
+ retriever = LLMRetriever(GitHubRepoManager.from_args(args), top_k=args.retriever_top_k)
 
 
 
 
 
206
  else:
207
+ if args.embedding_provider == "openai":
208
+ embeddings = OpenAIEmbeddings(model=args.embedding_model)
209
+ elif args.embedding_provider == "voyage":
210
+ embeddings = VoyageAIEmbeddings(model=args.embedding_model)
211
+ elif args.embedding_provider == "gemini":
212
+ embeddings = GoogleGenerativeAIEmbeddings(model=args.embedding_model)
213
+ else:
214
+ embeddings = None
215
 
216
+ retriever = build_vector_store_from_args(args, data_manager).as_retriever(
217
+ top_k=args.retriever_top_k, embeddings=embeddings, namespace=args.index_namespace
218
+ )
219
 
220
  if args.multi_query_retriever:
221
  retriever = MultiQueryRetriever.from_llm(