juliaturc committed on
Commit
2b64dc9
·
1 Parent(s): 21daf3d

Ensure the happy path works well (#94)

Browse files
sage/code_symbols.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utilities to extract code symbols (class and method names) from code files."""
2
+
3
+ import logging
+ from typing import List, Optional, Tuple
+
+ from tree_sitter import Node
+
+ from sage.chunker import CodeFileChunker
9
+
10
+
11
def _extract_classes_and_methods(
    node: Node, acc: List[Tuple[Optional[str], Optional[str]]], parent_class: Optional[str] = None
) -> None:
    """Extracts classes and methods from a tree-sitter node and places them in the `acc` accumulator.

    Args:
        node: The tree-sitter node to traverse.
        acc: Accumulator of (class_name, method_name) tuples. A class is recorded as
            (class_name, None); a function outside any class as (None, function_name).
        parent_class: Name of the enclosing class, if any, used to attribute methods.
    """
    if node.type in ("class_definition", "class_declaration"):
        class_name_node = node.child_by_field_name("name")
        if class_name_node:
            class_name = class_name_node.text.decode("utf-8")
            acc.append((class_name, None))
            for child in node.children:
                _extract_classes_and_methods(child, acc, class_name)
    elif node.type in ("function_definition", "method_definition"):
        function_name_node = node.child_by_field_name("name")
        if function_name_node:
            acc.append((parent_class, function_name_node.text.decode("utf-8")))
        # We're not going deeper into a method. This means we're missing nested functions.
    else:
        for child in node.children:
            _extract_classes_and_methods(child, acc, parent_class)
28
+
29
+
30
def get_code_symbols(file_path: str, content: str) -> List[Tuple[Optional[str], Optional[str]]]:
    """Extracts code symbols from a file.

    Code symbols are tuples of the form (class_name, method_name). For classes, method_name is None. For methods
    that do not belong to a class, class_name is None.

    Args:
        file_path: Path of the file, used to decide whether it is a code file and to pick a parser.
        content: The file's textual contents; empty/None content yields no symbols.

    Returns:
        A list of (class_name, method_name) tuples; empty for non-code or unparseable files.
    """
    if not CodeFileChunker.is_code_file(file_path):
        return []

    if not content:
        return []

    # Lazy %-formatting: the message is only rendered when INFO logging is actually enabled.
    logging.info("Extracting code symbols from %s", file_path)
    tree = CodeFileChunker.parse_tree(file_path, content)
    if not tree:
        return []

    classes_and_methods: List[Tuple[Optional[str], Optional[str]]] = []
    _extract_classes_and_methods(tree.root_node, classes_and_methods)
    return classes_and_methods
sage/configs/remote.yaml CHANGED
@@ -1,6 +1,9 @@
1
  llm-retriever: true
2
  llm-provider: anthropic
3
- reranker-provider: anthropic
 
 
 
4
 
5
  # The settings below (embeddings and vector store) are only relevant when setting --no-llm-retriever
6
 
 
1
  llm-retriever: true
2
  llm-provider: anthropic
3
+ # Here we optimize for ease of setup, so we skip the reranker which would require an extra API key.
4
+ reranker-provider: none
5
+ # Since we skipped the reranker, we can't afford to feed the retriever with too many candidates.
6
+ retriever-top-k: 5
7
 
8
  # The settings below (embeddings and vector store) are only relevant when setting --no-llm-retriever
9
 
sage/data_manager.py CHANGED
@@ -217,12 +217,8 @@ class GitHubRepoManager(DataManager):
217
  yield metadata
218
  continue
219
 
220
- with open(file_path, "r") as f:
221
- try:
222
- contents = f.read()
223
- except UnicodeDecodeError:
224
- logging.warning("Unable to decode file %s. Skipping.", file_path)
225
- continue
226
  yield contents, metadata
227
 
228
  def url_for_file(self, file_path: str) -> str:
@@ -231,10 +227,15 @@ class GitHubRepoManager(DataManager):
231
  return f"https://github.com/{self.repo_id}/blob/{self.default_branch}/{file_path}"
232
 
233
  def read_file(self, relative_file_path: str) -> str:
234
- """Reads the content of the file at the given path."""
235
- file_path = os.path.join(self.local_dir, relative_file_path)
236
- with open(file_path, "r") as f:
237
- return f.read()
 
 
 
 
 
238
 
239
  def from_args(args: Dict):
240
  """Creates a GitHubRepoManager from command-line arguments and clones the underlying repository."""
 
217
  yield metadata
218
  continue
219
 
220
+ contents = self.read_file(relative_file_path)
221
+ if contents:
 
 
 
 
222
  yield contents, metadata
223
 
224
  def url_for_file(self, file_path: str) -> str:
 
227
  return f"https://github.com/{self.repo_id}/blob/{self.default_branch}/{file_path}"
228
 
229
  def read_file(self, relative_file_path: str) -> str:
230
+ """Reads the contents of a file in the repository."""
231
+ absolute_file_path = os.path.join(self.local_dir, relative_file_path)
232
+ with open(absolute_file_path, "r") as f:
233
+ try:
234
+ contents = f.read()
235
+ return contents
236
+ except UnicodeDecodeError:
237
+ logging.warning("Unable to decode file %s.", absolute_file_path)
238
+ return None
239
 
240
  def from_args(args: Dict):
241
  """Creates a GitHubRepoManager from command-line arguments and clones the underlying repository."""
sage/reranker.py CHANGED
@@ -1,21 +1,14 @@
1
- import logging
2
  import os
3
  from enum import Enum
4
- from typing import List, Optional
5
 
6
  from langchain.retrievers.document_compressors import CrossEncoderReranker
7
  from langchain_cohere import CohereRerank
8
  from langchain_community.cross_encoders import HuggingFaceCrossEncoder
9
  from langchain_community.document_compressors import JinaRerank
10
- from langchain_core.callbacks.manager import Callbacks
11
- from langchain_core.documents import BaseDocumentCompressor, Document
12
- from langchain_core.language_models import BaseLanguageModel
13
- from langchain_core.prompts import PromptTemplate
14
  from langchain_nvidia_ai_endpoints import NVIDIARerank
15
  from langchain_voyageai import VoyageAIRerank
16
- from pydantic import ConfigDict, Field
17
-
18
- from sage.llm import build_llm_via_langchain
19
 
20
 
21
  class RerankerProvider(Enum):
@@ -25,58 +18,6 @@ class RerankerProvider(Enum):
25
  NVIDIA = "nvidia"
26
  JINA = "jina"
27
  VOYAGE = "voyage"
28
- # Anthropic doesn't provide an explicit reranker; we simply prompt the LLM with the user query and the content of
29
- # the top k documents.
30
- ANTHROPIC = "anthropic"
31
-
32
-
33
- class LLMReranker(BaseDocumentCompressor):
34
- """Reranker that passes the user query and top N documents to a language model to order them.
35
-
36
- Note that Langchain's RerankLLM does not support LLMs from Anthropic.
37
- https://python.langchain.com/api_reference/community/document_compressors/langchain_community.document_compressors.rankllm_rerank.RankLLMRerank.html
38
- Also, they rely on https://github.com/castorini/rank_llm, which doesn't run on Apple Silicon (M1/M2 chips).
39
- """
40
-
41
- llm: BaseLanguageModel = Field(...)
42
- top_k: int = Field(...)
43
-
44
- model_config = ConfigDict(
45
- arbitrary_types_allowed=True,
46
- extra="forbid",
47
- )
48
-
49
- @property
50
- def prompt(self):
51
- return PromptTemplate.from_template(
52
- "Given the following query: '{query}'\n\n"
53
- "And these documents:\n\n{documents}\n\n"
54
- "Rank the documents based on their relevance to the query. "
55
- "Return only the document numbers in order of relevance, separated by commas. For example: 2,5,1,3,4. "
56
- "Return absolutely nothing else."
57
- )
58
-
59
- def compress_documents(
60
- self,
61
- documents: List[Document],
62
- query: str,
63
- callbacks: Optional[Callbacks] = None,
64
- ) -> List[Document]:
65
- if len(documents) <= self.top_k:
66
- return documents
67
-
68
- doc_texts = [f"Document {i+1}:\n{doc.page_content}\n" for i, doc in enumerate(documents)]
69
- docs_str = "\n".join(doc_texts)
70
-
71
- llm_input = self.prompt.format(query=query, documents=docs_str)
72
- result = self.llm.predict(llm_input)
73
-
74
- try:
75
- ranked_indices = [int(idx) - 1 for idx in result.strip().split(",")][: self.top_k]
76
- return [documents[i] for i in ranked_indices]
77
- except ValueError:
78
- logging.warning("Failed to parse reranker output. Returning original order. LLM responded with: %s", result)
79
- return documents[: self.top_k]
80
 
81
 
82
  def build_reranker(provider: str, model: Optional[str] = None, top_k: Optional[int] = 5) -> BaseDocumentCompressor:
@@ -105,7 +46,4 @@ def build_reranker(provider: str, model: Optional[str] = None, top_k: Optional[i
105
  raise ValueError("Please set the VOYAGE_API_KEY environment variable")
106
  model = model or "rerank-1"
107
  return VoyageAIRerank(model=model, api_key=os.environ.get("VOYAGE_API_KEY"), top_k=top_k)
108
- if provider == RerankerProvider.ANTHROPIC.value:
109
- llm = build_llm_via_langchain("anthropic", model)
110
- return LLMReranker(llm=llm, top_k=1)
111
  raise ValueError(f"Invalid reranker provider: {provider}")
 
 
1
  import os
2
  from enum import Enum
3
+ from typing import Optional
4
 
5
  from langchain.retrievers.document_compressors import CrossEncoderReranker
6
  from langchain_cohere import CohereRerank
7
  from langchain_community.cross_encoders import HuggingFaceCrossEncoder
8
  from langchain_community.document_compressors import JinaRerank
9
+ from langchain_core.documents import BaseDocumentCompressor
 
 
 
10
  from langchain_nvidia_ai_endpoints import NVIDIARerank
11
  from langchain_voyageai import VoyageAIRerank
 
 
 
12
 
13
 
14
  class RerankerProvider(Enum):
 
18
  NVIDIA = "nvidia"
19
  JINA = "jina"
20
  VOYAGE = "voyage"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
 
23
  def build_reranker(provider: str, model: Optional[str] = None, top_k: Optional[int] = 5) -> BaseDocumentCompressor:
 
46
  raise ValueError("Please set the VOYAGE_API_KEY environment variable")
47
  model = model or "rerank-1"
48
  return VoyageAIRerank(model=model, api_key=os.environ.get("VOYAGE_API_KEY"), top_k=top_k)
 
 
 
49
  raise ValueError(f"Invalid reranker provider: {provider}")
sage/retriever.py CHANGED
@@ -1,6 +1,6 @@
1
  import logging
2
  import os
3
- from typing import List, Optional
4
 
5
  import anthropic
6
  import Levenshtein
@@ -9,12 +9,12 @@ from langchain.callbacks.manager import CallbackManagerForRetrieverRun
9
  from langchain.retrievers import ContextualCompressionRetriever
10
  from langchain.retrievers.multi_query import MultiQueryRetriever
11
  from langchain.schema import BaseRetriever, Document
12
- from langchain_core.output_parsers import CommaSeparatedListOutputParser
13
  from langchain_google_genai import GoogleGenerativeAIEmbeddings
14
  from langchain_openai import OpenAIEmbeddings
15
  from langchain_voyageai import VoyageAIEmbeddings
16
  from pydantic import Field
17
 
 
18
  from sage.data_manager import DataManager, GitHubRepoManager
19
  from sage.llm import build_llm_via_langchain
20
  from sage.reranker import build_reranker
@@ -24,6 +24,9 @@ logging.basicConfig(level=logging.INFO)
24
  logger = logging.getLogger()
25
  logger.setLevel(logging.INFO)
26
 
 
 
 
27
 
28
  class LLMRetriever(BaseRetriever):
29
  """Custom Langchain retriever based on an LLM.
@@ -37,21 +40,76 @@ class LLMRetriever(BaseRetriever):
37
 
38
  repo_manager: GitHubRepoManager = Field(...)
39
  top_k: int = Field(...)
40
- all_repo_files: List[str] = Field(...)
41
- repo_hierarchy: str = Field(...)
 
 
42
 
43
  def __init__(self, repo_manager: GitHubRepoManager, top_k: int):
44
  super().__init__()
45
  self.repo_manager = repo_manager
46
  self.top_k = top_k
47
 
48
- # Best practice would be to make these fields @cached_property, but that impedes class serialization.
49
- self.all_repo_files = [metadata["file_path"] for metadata in self.repo_manager.walk(get_content=False)]
50
- self.repo_hierarchy = LLMRetriever._render_file_hierarchy(self.all_repo_files)
 
 
 
 
51
 
52
  if not os.environ.get("ANTHROPIC_API_KEY"):
53
  raise ValueError("Please set the ANTHROPIC_API_KEY environment variable for the LLMRetriever.")
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  def _get_relevant_documents(self, query: str, *, run_manager: CallbackManagerForRetrieverRun) -> List[Document]:
56
  """Retrieve relevant documents for a given query."""
57
  filenames = self._ask_llm_to_retrieve(user_query=query, top_k=self.top_k)
@@ -66,13 +124,26 @@ class LLMRetriever(BaseRetriever):
66
 
67
  def _ask_llm_to_retrieve(self, user_query: str, top_k: int) -> List[str]:
68
  """Feeds the file hierarchy and user query to the LLM and asks which files might be relevant."""
 
69
  sys_prompt = f"""
70
- You are a retriever system. You will be given a user query and a list of files in a GitHub repository. Your task is to determine the top {top_k} files that are most relevant to the user query.
 
 
 
 
 
 
 
 
 
 
 
 
71
  DO NOT RESPOND TO THE USER QUERY DIRECTLY. Instead, respond with full paths to relevant files that could contain the answer to the query. Say absolutely nothing else other than the file paths.
72
 
73
- Here is the file hierarchy of the GitHub repository:
74
 
75
- {self.repo_hierarchy}
76
  """
77
 
78
  # We are deliberately repeating the "DO NOT RESPOND TO THE USER QUERY DIRECTLY" instruction here.
@@ -82,20 +153,21 @@ User query: {user_query}
82
  DO NOT RESPOND TO THE USER QUERY DIRECTLY. Instead, respond with full paths to relevant files that could contain the answer to the query. Say absolutely nothing else other than the file paths.
83
  """
84
  response = LLMRetriever._call_via_anthropic_with_prompt_caching(sys_prompt, augmented_user_query)
 
85
  files_from_llm = response.content[0].text.strip().split("\n")
86
  validated_files = []
87
 
88
  for filename in files_from_llm:
89
- if filename not in self.all_repo_files:
90
  if "/" not in filename:
91
  # This is most likely some natural language excuse from the LLM; skip it.
92
  continue
93
  # Try a few heuristics to fix the filename.
94
  filename = LLMRetriever._fix_filename(filename, self.repo_manager.repo_id)
95
- if filename not in self.all_repo_files:
96
  # The heuristics failed; try to find the closest filename in the repo.
97
- filename = LLMRetriever._find_closest_filename(filename, self.all_repo_files)
98
- if filename in self.all_repo_files:
99
  validated_files.append(filename)
100
  return validated_files
101
 
@@ -108,13 +180,10 @@ DO NOT RESPOND TO THE USER QUERY DIRECTLY. Instead, respond with full paths to r
108
  We're circumventing LangChain for now, because the feature is < 1 week old at the time of writing and has no
109
  documentation: https://github.com/langchain-ai/langchain/pull/27087
110
  """
111
- CLAUDE_MODEL = "claude-3-5-sonnet-20240620"
112
- client = anthropic.Anthropic()
113
-
114
  system_message = {"type": "text", "text": system_prompt, "cache_control": {"type": "ephemeral"}}
115
  user_message = {"role": "user", "content": user_prompt}
116
 
117
- response = client.beta.prompt_caching.messages.create(
118
  model=CLAUDE_MODEL,
119
  max_tokens=1024, # The maximum number of *output* tokens to generate.
120
  system=[system_message],
@@ -126,34 +195,66 @@ DO NOT RESPOND TO THE USER QUERY DIRECTLY. Instead, respond with full paths to r
126
  return response
127
 
128
  @staticmethod
129
- def _render_file_hierarchy(file_paths: List[str]) -> str:
130
- """Given a list of files, produces a visualization of the file hierarchy. For instance:
 
 
 
 
 
 
 
 
 
 
 
131
  folder1
132
  folder11
133
- file111.py
134
  file112.py
 
 
 
 
135
  folder12
136
  file121.py
 
 
137
  folder2
138
  file21.py
139
  """
140
  # The "nodepath" is the path from root to the node (e.g. huggingface/transformers/examples)
141
  nodepath_to_node = {}
142
 
143
- for path in file_paths:
144
- items = path.split("/")
145
- nodepath = ""
146
- parent_node = None
147
- for item in items:
148
- nodepath = f"{nodepath}/{item}"
149
- if nodepath in nodepath_to_node:
150
- node = nodepath_to_node[nodepath]
151
- else:
152
- node = Node(item, parent=parent_node)
153
- nodepath_to_node[nodepath] = node
154
- parent_node = node
155
-
156
- root_path = f"/{file_paths[0].split('/')[0]}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  full_render = ""
158
  root_node = nodepath_to_node[root_path]
159
  for pre, fill, node in RenderTree(root_node):
@@ -200,6 +301,24 @@ DO NOT RESPOND TO THE USER QUERY DIRECTLY. Instead, respond with full paths to r
200
  return None
201
 
202
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  def build_retriever_from_args(args, data_manager: Optional[DataManager] = None):
204
  """Builds a retriever (with optional reranking) from command-line arguments."""
205
  if args.llm_retriever:
 
1
  import logging
2
  import os
3
+ from typing import Dict, List, Optional
4
 
5
  import anthropic
6
  import Levenshtein
 
9
  from langchain.retrievers import ContextualCompressionRetriever
10
  from langchain.retrievers.multi_query import MultiQueryRetriever
11
  from langchain.schema import BaseRetriever, Document
 
12
  from langchain_google_genai import GoogleGenerativeAIEmbeddings
13
  from langchain_openai import OpenAIEmbeddings
14
  from langchain_voyageai import VoyageAIEmbeddings
15
  from pydantic import Field
16
 
17
+ from sage.code_symbols import get_code_symbols
18
  from sage.data_manager import DataManager, GitHubRepoManager
19
  from sage.llm import build_llm_via_langchain
20
  from sage.reranker import build_reranker
 
24
  logger = logging.getLogger()
25
  logger.setLevel(logging.INFO)
26
 
27
+ CLAUDE_MODEL = "claude-3-5-sonnet-20240620"
28
+ CLAUDE_MODEL_CONTEXT_SIZE = 200_000
29
+
30
 
31
  class LLMRetriever(BaseRetriever):
32
  """Custom Langchain retriever based on an LLM.
 
40
 
41
  repo_manager: GitHubRepoManager = Field(...)
42
  top_k: int = Field(...)
43
+
44
+ cached_repo_metadata: List[Dict] = Field(...)
45
+ cached_repo_files: List[str] = Field(...)
46
+ cached_repo_hierarchy: str = Field(...)
47
 
48
  def __init__(self, repo_manager: GitHubRepoManager, top_k: int):
49
  super().__init__()
50
  self.repo_manager = repo_manager
51
  self.top_k = top_k
52
 
53
+ # We cached these fields manually because:
54
+ # 1. Pydantic doesn't work with functools's @cached_property.
55
+ # 2. We can't use Pydantic's @computed_field because these fields depend on each other.
56
+ # 3. We can't use functools's @lru_cache because LLMRetriever needs to be hashable.
57
+ self.cached_repo_metadata = None
58
+ self.cached_repo_files = None
59
+ self.cached_repo_hierarchy = None
60
 
61
  if not os.environ.get("ANTHROPIC_API_KEY"):
62
  raise ValueError("Please set the ANTHROPIC_API_KEY environment variable for the LLMRetriever.")
63
 
64
+ @property
65
+ def repo_metadata(self):
66
+ if not self.cached_repo_metadata:
67
+ self.cached_repo_metadata = [metadata for metadata in self.repo_manager.walk(get_content=False)]
68
+
69
+ # Extracting code symbols takes quite a while, since we need to read each file from disk.
70
+ # As a compromise, we do it for small codebases only.
71
+ small_codebase = len(self.repo_files) <= 200
72
+ if small_codebase:
73
+ for metadata in self.cached_repo_metadata:
74
+ file_path = metadata["file_path"]
75
+ content = self.repo_manager.read_file(file_path)
76
+ metadata["code_symbols"] = get_code_symbols(file_path, content)
77
+
78
+ return self.cached_repo_metadata
79
+
80
+ @property
81
+ def repo_files(self):
82
+ if not self.cached_repo_files:
83
+ self.cached_repo_files = set(metadata["file_path"] for metadata in self.repo_metadata)
84
+ return self.cached_repo_files
85
+
86
    @property
    def repo_hierarchy(self):
        """Produces a string that describes the structure of the repository. Depending on how big the codebase is, it
        might include class and method names."""
        if self.cached_repo_hierarchy is None:
            # Start from the richest rendering (classes + methods) and progressively degrade
            # until the rendering fits in the model's context window.
            render = LLMRetriever._render_file_hierarchy(self.repo_metadata, include_classes=True, include_methods=True)
            max_tokens = CLAUDE_MODEL_CONTEXT_SIZE - 50_000  # 50,000 tokens for other parts of the prompt.
            client = anthropic.Anthropic()
            if client.count_tokens(render) > max_tokens:
                logging.info("File hierarchy is too large; excluding methods.")
                render = LLMRetriever._render_file_hierarchy(
                    self.repo_metadata, include_classes=True, include_methods=False
                )
            if client.count_tokens(render) > max_tokens:
                logging.info("File hierarchy is still too large; excluding classes.")
                render = LLMRetriever._render_file_hierarchy(
                    self.repo_metadata, include_classes=False, include_methods=False
                )
            if client.count_tokens(render) > max_tokens:
                logging.info("File hierarchy is still too large; truncating.")
                # NOTE(review): `anthropic.Tokenizer()` does not appear to be a public class of the
                # anthropic SDK, and `client.count_tokens` is deprecated in newer SDK versions --
                # confirm this truncation branch actually works against the pinned SDK version.
                tokenizer = anthropic.Tokenizer()
                tokens = tokenizer.tokenize(render)[:max_tokens]
                render = tokenizer.detokenize(tokens)
            logging.info("Number of tokens in render hierarchy: %d", client.count_tokens(render))
            self.cached_repo_hierarchy = render
        return self.cached_repo_hierarchy
112
+
113
  def _get_relevant_documents(self, query: str, *, run_manager: CallbackManagerForRetrieverRun) -> List[Document]:
114
  """Retrieve relevant documents for a given query."""
115
  filenames = self._ask_llm_to_retrieve(user_query=query, top_k=self.top_k)
 
124
 
125
  def _ask_llm_to_retrieve(self, user_query: str, top_k: int) -> List[str]:
126
  """Feeds the file hierarchy and user query to the LLM and asks which files might be relevant."""
127
+ repo_hierarchy = str(self.repo_hierarchy)
128
  sys_prompt = f"""
129
+ You are a retriever system. You will be given a user query and a list of files in a GitHub repository, together with the class names in each file.
130
+
131
+ For instance:
132
+ folder1
133
+ folder2
134
+ folder3
135
+ file123.py
136
+ ClassName1
137
+ ClassName2
138
+ ClassName3
139
+ means that there is a file with path folder1/folder2/folder3/file123.py, which contains classes ClassName1, ClassName2, and ClassName3.
140
+
141
+ Your task is to determine the top {top_k} files that are most relevant to the user query.
142
  DO NOT RESPOND TO THE USER QUERY DIRECTLY. Instead, respond with full paths to relevant files that could contain the answer to the query. Say absolutely nothing else other than the file paths.
143
 
144
+ Here is the file hierarchy of the GitHub repository, together with the class names in each file:
145
 
146
+ {repo_hierarchy}
147
  """
148
 
149
  # We are deliberately repeating the "DO NOT RESPOND TO THE USER QUERY DIRECTLY" instruction here.
 
153
  DO NOT RESPOND TO THE USER QUERY DIRECTLY. Instead, respond with full paths to relevant files that could contain the answer to the query. Say absolutely nothing else other than the file paths.
154
  """
155
  response = LLMRetriever._call_via_anthropic_with_prompt_caching(sys_prompt, augmented_user_query)
156
+
157
  files_from_llm = response.content[0].text.strip().split("\n")
158
  validated_files = []
159
 
160
  for filename in files_from_llm:
161
+ if filename not in self.repo_files:
162
  if "/" not in filename:
163
  # This is most likely some natural language excuse from the LLM; skip it.
164
  continue
165
  # Try a few heuristics to fix the filename.
166
  filename = LLMRetriever._fix_filename(filename, self.repo_manager.repo_id)
167
+ if filename not in self.repo_files:
168
  # The heuristics failed; try to find the closest filename in the repo.
169
+ filename = LLMRetriever._find_closest_filename(filename, self.repo_files)
170
+ if filename in self.repo_files:
171
  validated_files.append(filename)
172
  return validated_files
173
 
 
180
  We're circumventing LangChain for now, because the feature is < 1 week old at the time of writing and has no
181
  documentation: https://github.com/langchain-ai/langchain/pull/27087
182
  """
 
 
 
183
  system_message = {"type": "text", "text": system_prompt, "cache_control": {"type": "ephemeral"}}
184
  user_message = {"role": "user", "content": user_prompt}
185
 
186
+ response = anthropic.Anthropic().beta.prompt_caching.messages.create(
187
  model=CLAUDE_MODEL,
188
  max_tokens=1024, # The maximum number of *output* tokens to generate.
189
  system=[system_message],
 
195
  return response
196
 
197
  @staticmethod
198
+ def _render_file_hierarchy(
199
+ repo_metadata: List[Dict], include_classes: bool = True, include_methods: bool = True
200
+ ) -> str:
201
+ """Given a list of files, produces a visualization of the file hierarchy. This hierarchy optionally includes
202
+ class and method names, if available.
203
+
204
+ For large codebases, including both classes and methods might exceed the token limit of the LLM. In that case,
205
+ try setting `include_methods=False` first. If that's still too long, try also setting `include_classes=False`.
206
+
207
+ As a point of reference, the Transformers library requires setting `include_methods=False` to fit within
208
+ Claude's 200k context.
209
+
210
+ Example:
211
  folder1
212
  folder11
213
+ file111.md
214
  file112.py
215
+ ClassName1
216
+ method_name1
217
+ method_name2
218
+ method_name3
219
  folder12
220
  file121.py
221
+ ClassName2
222
+ ClassName3
223
  folder2
224
  file21.py
225
  """
226
  # The "nodepath" is the path from root to the node (e.g. huggingface/transformers/examples)
227
  nodepath_to_node = {}
228
 
229
+ for metadata in repo_metadata:
230
+ path = metadata["file_path"]
231
+ paths = [path]
232
+
233
+ if include_classes or include_methods:
234
+ # Add the code symbols to the path. For instance, "folder/myfile.py/ClassName/method_name".
235
+ for class_name, method_name in metadata.get("code_symbols", []):
236
+ if include_classes and class_name:
237
+ paths.append(path + "/" + class_name)
238
+ # We exclude private methods to save tokens.
239
+ if include_methods and method_name and not method_name.startswith("_"):
240
+ paths.append(
241
+ path + "/" + class_name + "/" + method_name if class_name else path + "/" + method_name
242
+ )
243
+
244
+ for path in paths:
245
+ items = path.split("/")
246
+ nodepath = ""
247
+ parent_node = None
248
+ for item in items:
249
+ nodepath = f"{nodepath}/{item}"
250
+ if nodepath in nodepath_to_node:
251
+ node = nodepath_to_node[nodepath]
252
+ else:
253
+ node = Node(item, parent=parent_node)
254
+ nodepath_to_node[nodepath] = node
255
+ parent_node = node
256
+
257
+ root_path = "/" + repo_metadata[0]["file_path"].split("/")[0]
258
  full_render = ""
259
  root_node = nodepath_to_node[root_path]
260
  for pre, fill, node in RenderTree(root_node):
 
301
  return None
302
 
303
 
304
class RerankerWithErrorHandling(BaseRetriever):
    """Wraps a `ContextualCompressionRetriever` to catch errors during inference.

    In practice, we see occasional `requests.exceptions.ReadTimeout` from the NVIDIA reranker, which crash the entire
    pipeline. This wrapper catches such exceptions by simply returning the documents in the original order.
    """

    # Declared as a pydantic field: `BaseRetriever` is a pydantic model, so assigning an
    # undeclared attribute in __init__ (`self.reranker = ...`) would fail at construction time.
    reranker: ContextualCompressionRetriever = Field(...)

    def __init__(self, reranker: ContextualCompressionRetriever):
        super().__init__(reranker=reranker)

    def _get_relevant_documents(self, query: str, *, run_manager=None) -> List[Document]:
        """Runs the wrapped reranker; on any error, falls back to the base retriever's order."""
        try:
            return self.reranker._get_relevant_documents(query, run_manager=run_manager)
        except Exception as e:
            # Lazy %-args: the message is only rendered if this record is actually emitted.
            logging.error("Error in reranker; preserving original document order from retriever. %s", e)
            return self.reranker.base_retriever._get_relevant_documents(query, run_manager=run_manager)
320
+
321
+
322
  def build_retriever_from_args(args, data_manager: Optional[DataManager] = None):
323
  """Builds a retriever (with optional reranking) from command-line arguments."""
324
  if args.llm_retriever: