Ensure the happy path works well (#94)
- sage/code_symbols.py +49 -0
- sage/configs/remote.yaml +4 -1
- sage/data_manager.py +11 -10
- sage/reranker.py +2 -64
- sage/retriever.py +154 -35
sage/code_symbols.py
ADDED

@@ -0,0 +1,49 @@
+"""Utilities to extract code symbols (class and method names) from code files."""
+
+import logging
+from typing import List, Tuple
+
+from tree_sitter import Node
+
+from sage.chunker import CodeFileChunker
+
+
+def _extract_classes_and_methods(node: Node, acc: List[Tuple[str, str]], parent_class: str = None):
+    """Extracts classes and methods from a tree-sitter node and places them in the `acc` accumulator."""
+    if node.type in ["class_definition", "class_declaration"]:
+        class_name_node = node.child_by_field_name("name")
+        if class_name_node:
+            class_name = class_name_node.text.decode("utf-8")
+            acc.append((class_name, None))
+            for child in node.children:
+                _extract_classes_and_methods(child, acc, class_name)
+    elif node.type in ["function_definition", "method_definition"]:
+        function_name_node = node.child_by_field_name("name")
+        if function_name_node:
+            acc.append((parent_class, function_name_node.text.decode("utf-8")))
+        # We're not going deeper into a method. This means we're missing nested functions.
+    else:
+        for child in node.children:
+            _extract_classes_and_methods(child, acc, parent_class)
+
+
+def get_code_symbols(file_path: str, content: str) -> List[Tuple[str, str]]:
+    """Extracts code symbols from a file.
+
+    Code symbols are tuples of the form (class_name, method_name). For classes, method_name is None. For methods
+    that do not belong to a class, class_name is None.
+    """
+    if not CodeFileChunker.is_code_file(file_path):
+        return []
+
+    if not content:
+        return []
+
+    logging.info(f"Extracting code symbols from {file_path}")
+    tree = CodeFileChunker.parse_tree(file_path, content)
+    if not tree:
+        return []
+
+    classes_and_methods = []
+    _extract_classes_and_methods(tree.root_node, classes_and_methods)
+    return classes_and_methods
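To make the `(class_name, method_name)` convention concrete, here is a minimal usage sketch. It assumes `sage` and its tree-sitter grammars are installed; the expected output is inferred from the docstring above rather than from a test in this commit.

```python
# Hypothetical usage of get_code_symbols; the output shape follows the
# docstring's (class_name, method_name) convention.
from sage.code_symbols import get_code_symbols

source = '''
class Greeter:
    def greet(self):
        return "hello"

def main():
    print(Greeter().greet())
'''

symbols = get_code_symbols("example.py", source)
# Expected: [("Greeter", None), ("Greeter", "greet"), (None, "main")]
print(symbols)
```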
sage/configs/remote.yaml
CHANGED

@@ -1,6 +1,9 @@
 llm-retriever: true
 llm-provider: anthropic
-reranker
+# Here we optimize for ease of setup, so we skip the reranker which would require an extra API key.
+reranker-provider: none
+# Since we skipped the reranker, we can't afford to feed the retriever with too many candidates.
+retriever-top-k: 5
 
 # The settings below (embeddings and vector store) are only relevant when setting --no-llm-retriever
 
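For readers wiring up a similar setup by hand, below is a minimal sketch of how a `none` provider could be short-circuited before any reranker is built. The `args` attribute names are illustrative assumptions, not necessarily the repo's actual flag names.

```python
# Hypothetical glue code: skip reranking when the config says "none".
# The names args.reranker_provider and args.retriever_top_k are assumptions.
from sage.reranker import build_reranker

def maybe_build_reranker(args):
    if args.reranker_provider == "none":
        # No reranker: the retriever's own top-k ordering is used as-is,
        # which is why the config above lowers retriever-top-k to 5.
        return None
    return build_reranker(args.reranker_provider, top_k=args.retriever_top_k)
```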
sage/data_manager.py
CHANGED

@@ -217,12 +217,8 @@ class GitHubRepoManager(DataManager):
                 yield metadata
                 continue
 
-            with open(file_path, "r") as f:
-                try:
-                    contents = f.read()
-                except UnicodeDecodeError:
-                    logging.warning("Unable to decode file %s. Skipping.", file_path)
-                    continue
+            contents = self.read_file(relative_file_path)
+            if contents:
                 yield contents, metadata
 
     def url_for_file(self, file_path: str) -> str:
@@ -231,10 +227,15 @@
         return f"https://github.com/{self.repo_id}/blob/{self.default_branch}/{file_path}"
 
     def read_file(self, relative_file_path: str) -> str:
-        """Reads the
-
-        with open(
-
+        """Reads the contents of a file in the repository."""
+        absolute_file_path = os.path.join(self.local_dir, relative_file_path)
+        with open(absolute_file_path, "r") as f:
+            try:
+                contents = f.read()
+                return contents
+            except UnicodeDecodeError:
+                logging.warning("Unable to decode file %s.", absolute_file_path)
+                return None
 
     def from_args(args: Dict):
         """Creates a GitHubRepoManager from command-line arguments and clones the underlying repository."""
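The refactoring gives `read_file` a clear contract: it returns the decoded contents, or `None` for binary or non-UTF-8 files, and `walk` now reuses it instead of duplicating the decode logic. A small sketch of the new contract, assuming a manager built via `from_args`:

```python
# Sketch of the new read_file contract; `args` stands in for parsed CLI
# arguments, as expected by from_args in the diff above.
from sage.data_manager import GitHubRepoManager

manager = GitHubRepoManager.from_args(args)
contents = manager.read_file("README.md")
if contents is None:
    # The file could not be decoded as text; walk() now skips such files
    # instead of crashing mid-iteration.
    print("Not a text file; skipped.")
```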
sage/reranker.py
CHANGED

@@ -1,21 +1,14 @@
-import logging
 import os
 from enum import Enum
-from typing import
+from typing import Optional
 
 from langchain.retrievers.document_compressors import CrossEncoderReranker
 from langchain_cohere import CohereRerank
 from langchain_community.cross_encoders import HuggingFaceCrossEncoder
 from langchain_community.document_compressors import JinaRerank
-from langchain_core.
-from langchain_core.documents import BaseDocumentCompressor, Document
-from langchain_core.language_models import BaseLanguageModel
-from langchain_core.prompts import PromptTemplate
+from langchain_core.documents import BaseDocumentCompressor
 from langchain_nvidia_ai_endpoints import NVIDIARerank
 from langchain_voyageai import VoyageAIRerank
-from pydantic import ConfigDict, Field
-
-from sage.llm import build_llm_via_langchain
 
 
 class RerankerProvider(Enum):
@@ -25,58 +18,6 @@ class RerankerProvider(Enum):
     NVIDIA = "nvidia"
     JINA = "jina"
     VOYAGE = "voyage"
-    # Anthropic doesn't provide an explicit reranker; we simply prompt the LLM with the user query and the content of
-    # the top k documents.
-    ANTHROPIC = "anthropic"
-
-
-class LLMReranker(BaseDocumentCompressor):
-    """Reranker that passes the user query and top N documents to a language model to order them.
-
-    Note that Langchain's RerankLLM does not support LLMs from Anthropic.
-    https://python.langchain.com/api_reference/community/document_compressors/langchain_community.document_compressors.rankllm_rerank.RankLLMRerank.html
-    Also, they rely on https://github.com/castorini/rank_llm, which doesn't run on Apple Silicon (M1/M2 chips).
-    """
-
-    llm: BaseLanguageModel = Field(...)
-    top_k: int = Field(...)
-
-    model_config = ConfigDict(
-        arbitrary_types_allowed=True,
-        extra="forbid",
-    )
-
-    @property
-    def prompt(self):
-        return PromptTemplate.from_template(
-            "Given the following query: '{query}'\n\n"
-            "And these documents:\n\n{documents}\n\n"
-            "Rank the documents based on their relevance to the query. "
-            "Return only the document numbers in order of relevance, separated by commas. For example: 2,5,1,3,4. "
-            "Return absolutely nothing else."
-        )
-
-    def compress_documents(
-        self,
-        documents: List[Document],
-        query: str,
-        callbacks: Optional[Callbacks] = None,
-    ) -> List[Document]:
-        if len(documents) <= self.top_k:
-            return documents
-
-        doc_texts = [f"Document {i+1}:\n{doc.page_content}\n" for i, doc in enumerate(documents)]
-        docs_str = "\n".join(doc_texts)
-
-        llm_input = self.prompt.format(query=query, documents=docs_str)
-        result = self.llm.predict(llm_input)
-
-        try:
-            ranked_indices = [int(idx) - 1 for idx in result.strip().split(",")][: self.top_k]
-            return [documents[i] for i in ranked_indices]
-        except ValueError:
-            logging.warning("Failed to parse reranker output. Returning original order. LLM responded with: %s", result)
-            return documents[: self.top_k]
 
 
 def build_reranker(provider: str, model: Optional[str] = None, top_k: Optional[int] = 5) -> BaseDocumentCompressor:
@@ -105,7 +46,4 @@ def build_reranker(provider: str, model: Optional[str] = None, top_k: Optional[i
             raise ValueError("Please set the VOYAGE_API_KEY environment variable")
         model = model or "rerank-1"
         return VoyageAIRerank(model=model, api_key=os.environ.get("VOYAGE_API_KEY"), top_k=top_k)
-    if provider == RerankerProvider.ANTHROPIC.value:
-        llm = build_llm_via_langchain("anthropic", model)
-        return LLMReranker(llm=llm, top_k=1)
     raise ValueError(f"Invalid reranker provider: {provider}")
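With the LLM-based fallback gone, `build_reranker` only accepts the explicit providers in `RerankerProvider`. A usage sketch for the Voyage path visible in the hunk above; the key value is a placeholder:

```python
# Usage sketch for the Voyage reranker path shown above.
import os
from sage.reranker import build_reranker

os.environ["VOYAGE_API_KEY"] = "pa-..."  # placeholder; use a real key
reranker = build_reranker(provider="voyage", top_k=5)  # model defaults to "rerank-1"
```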
sage/retriever.py
CHANGED
@@ -1,6 +1,6 @@
 import logging
 import os
-from typing import List, Optional
+from typing import Dict, List, Optional
 
 import anthropic
 import Levenshtein
@@ -9,12 +9,12 @@ from langchain.callbacks.manager import CallbackManagerForRetrieverRun
 from langchain.retrievers import ContextualCompressionRetriever
 from langchain.retrievers.multi_query import MultiQueryRetriever
 from langchain.schema import BaseRetriever, Document
-from langchain_core.output_parsers import CommaSeparatedListOutputParser
 from langchain_google_genai import GoogleGenerativeAIEmbeddings
 from langchain_openai import OpenAIEmbeddings
 from langchain_voyageai import VoyageAIEmbeddings
 from pydantic import Field
 
+from sage.code_symbols import get_code_symbols
 from sage.data_manager import DataManager, GitHubRepoManager
 from sage.llm import build_llm_via_langchain
 from sage.reranker import build_reranker
@@ -24,6 +24,9 @@ logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger()
 logger.setLevel(logging.INFO)
 
+CLAUDE_MODEL = "claude-3-5-sonnet-20240620"
+CLAUDE_MODEL_CONTEXT_SIZE = 200_000
+
 
 class LLMRetriever(BaseRetriever):
     """Custom Langchain retriever based on an LLM.
@@ -37,21 +40,76 @@ class LLMRetriever(BaseRetriever):
 
     repo_manager: GitHubRepoManager = Field(...)
    top_k: int = Field(...)
-
-
+
+    cached_repo_metadata: List[Dict] = Field(...)
+    cached_repo_files: List[str] = Field(...)
+    cached_repo_hierarchy: str = Field(...)
 
     def __init__(self, repo_manager: GitHubRepoManager, top_k: int):
         super().__init__()
         self.repo_manager = repo_manager
         self.top_k = top_k
 
-        #
-
-
+        # We cache these fields manually because:
+        # 1. Pydantic doesn't work with functools's @cached_property.
+        # 2. We can't use Pydantic's @computed_field because these fields depend on each other.
+        # 3. We can't use functools's @lru_cache because LLMRetriever needs to be hashable.
+        self.cached_repo_metadata = None
+        self.cached_repo_files = None
+        self.cached_repo_hierarchy = None
 
         if not os.environ.get("ANTHROPIC_API_KEY"):
             raise ValueError("Please set the ANTHROPIC_API_KEY environment variable for the LLMRetriever.")
 
+    @property
+    def repo_metadata(self):
+        if not self.cached_repo_metadata:
+            self.cached_repo_metadata = [metadata for metadata in self.repo_manager.walk(get_content=False)]
+
+            # Extracting code symbols takes quite a while, since we need to read each file from disk.
+            # As a compromise, we do it for small codebases only.
+            small_codebase = len(self.repo_files) <= 200
+            if small_codebase:
+                for metadata in self.cached_repo_metadata:
+                    file_path = metadata["file_path"]
+                    content = self.repo_manager.read_file(file_path)
+                    metadata["code_symbols"] = get_code_symbols(file_path, content)
+
+        return self.cached_repo_metadata
+
+    @property
+    def repo_files(self):
+        if not self.cached_repo_files:
+            self.cached_repo_files = set(metadata["file_path"] for metadata in self.repo_metadata)
+        return self.cached_repo_files
+
+    @property
+    def repo_hierarchy(self):
+        """Produces a string that describes the structure of the repository. Depending on how big the codebase is, it
+        might include class and method names."""
+        if self.cached_repo_hierarchy is None:
+            render = LLMRetriever._render_file_hierarchy(self.repo_metadata, include_classes=True, include_methods=True)
+            max_tokens = CLAUDE_MODEL_CONTEXT_SIZE - 50_000  # 50,000 tokens for other parts of the prompt.
+            client = anthropic.Anthropic()
+            if client.count_tokens(render) > max_tokens:
+                logging.info("File hierarchy is too large; excluding methods.")
+                render = LLMRetriever._render_file_hierarchy(
+                    self.repo_metadata, include_classes=True, include_methods=False
+                )
+                if client.count_tokens(render) > max_tokens:
+                    logging.info("File hierarchy is still too large; excluding classes.")
+                    render = LLMRetriever._render_file_hierarchy(
+                        self.repo_metadata, include_classes=False, include_methods=False
+                    )
+                    if client.count_tokens(render) > max_tokens:
+                        logging.info("File hierarchy is still too large; truncating.")
+                        tokenizer = anthropic.Tokenizer()
+                        tokens = tokenizer.tokenize(render)[:max_tokens]
+                        render = tokenizer.detokenize(tokens)
+            logging.info("Number of tokens in render hierarchy: %d", client.count_tokens(render))
+            self.cached_repo_hierarchy = render
+        return self.cached_repo_hierarchy
+
     def _get_relevant_documents(self, query: str, *, run_manager: CallbackManagerForRetrieverRun) -> List[Document]:
         """Retrieve relevant documents for a given query."""
         filenames = self._ask_llm_to_retrieve(user_query=query, top_k=self.top_k)
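The `cached_repo_*` fields above implement lazy caching by hand because, as the inline comment explains, Pydantic models do not compose with `functools.cached_property`. For a plain class the same behavior is a single decorator; a sketch of the standard-library equivalent being worked around:

```python
# Standard-library equivalent of the manual caching above. This only works
# on plain classes, not Pydantic models like LLMRetriever.
from functools import cached_property

class RepoIndex:
    def __init__(self, repo_manager):
        self.repo_manager = repo_manager

    @cached_property
    def repo_files(self) -> set:
        # Computed on first access, then stored on the instance.
        return {m["file_path"] for m in self.repo_manager.walk(get_content=False)}
```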
@@ -66,13 +124,26 @@ class LLMRetriever(BaseRetriever):
 
     def _ask_llm_to_retrieve(self, user_query: str, top_k: int) -> List[str]:
         """Feeds the file hierarchy and user query to the LLM and asks which files might be relevant."""
+        repo_hierarchy = str(self.repo_hierarchy)
         sys_prompt = f"""
-You are a retriever system. You will be given a user query and a list of files in a GitHub repository
+You are a retriever system. You will be given a user query and a list of files in a GitHub repository, together with the class names in each file.
+
+For instance:
+folder1
+    folder2
+        folder3
+            file123.py
+                ClassName1
+                ClassName2
+                ClassName3
+means that there is a file with path folder1/folder2/folder3/file123.py, which contains classes ClassName1, ClassName2, and ClassName3.
+
+Your task is to determine the top {top_k} files that are most relevant to the user query.
 DO NOT RESPOND TO THE USER QUERY DIRECTLY. Instead, respond with full paths to relevant files that could contain the answer to the query. Say absolutely nothing else other than the file paths.
 
-Here is the file hierarchy of the GitHub repository:
+Here is the file hierarchy of the GitHub repository, together with the class names in each file:
 
-{
+{repo_hierarchy}
 """
 
         # We are deliberately repeating the "DO NOT RESPOND TO THE USER QUERY DIRECTLY" instruction here.
@@ -82,20 +153,21 @@ User query: {user_query}
 DO NOT RESPOND TO THE USER QUERY DIRECTLY. Instead, respond with full paths to relevant files that could contain the answer to the query. Say absolutely nothing else other than the file paths.
 """
         response = LLMRetriever._call_via_anthropic_with_prompt_caching(sys_prompt, augmented_user_query)
+
         files_from_llm = response.content[0].text.strip().split("\n")
         validated_files = []
 
         for filename in files_from_llm:
-            if filename not in self.
+            if filename not in self.repo_files:
                 if "/" not in filename:
                     # This is most likely some natural language excuse from the LLM; skip it.
                     continue
                 # Try a few heuristics to fix the filename.
                 filename = LLMRetriever._fix_filename(filename, self.repo_manager.repo_id)
-                if filename not in self.
+                if filename not in self.repo_files:
                     # The heuristics failed; try to find the closest filename in the repo.
-                    filename = LLMRetriever._find_closest_filename(filename, self.
-            if filename in self.
+                    filename = LLMRetriever._find_closest_filename(filename, self.repo_files)
+            if filename in self.repo_files:
                 validated_files.append(filename)
         return validated_files
 
@@ -108,13 +180,10 @@ DO NOT RESPOND TO THE USER QUERY DIRECTLY. Instead, respond with full paths to r
         We're circumventing LangChain for now, because the feature is < 1 week old at the time of writing and has no
         documentation: https://github.com/langchain-ai/langchain/pull/27087
         """
-        CLAUDE_MODEL = "claude-3-5-sonnet-20240620"
-        client = anthropic.Anthropic()
-
         system_message = {"type": "text", "text": system_prompt, "cache_control": {"type": "ephemeral"}}
         user_message = {"role": "user", "content": user_prompt}
 
-        response =
+        response = anthropic.Anthropic().beta.prompt_caching.messages.create(
             model=CLAUDE_MODEL,
             max_tokens=1024,  # The maximum number of *output* tokens to generate.
             system=[system_message],
@@ -126,34 +195,66 @@ DO NOT RESPOND TO THE USER QUERY DIRECTLY. Instead, respond with full paths to r
         return response
 
     @staticmethod
-    def _render_file_hierarchy(
-
+    def _render_file_hierarchy(
+        repo_metadata: List[Dict], include_classes: bool = True, include_methods: bool = True
+    ) -> str:
+        """Given a list of files, produces a visualization of the file hierarchy. This hierarchy optionally includes
+        class and method names, if available.
+
+        For large codebases, including both classes and methods might exceed the token limit of the LLM. In that case,
+        try setting `include_methods=False` first. If that's still too long, try also setting `include_classes=False`.
+
+        As a point of reference, the Transformers library requires setting `include_methods=False` to fit within
+        Claude's 200k context.
+
+        Example:
         folder1
             folder11
-                file111.
+                file111.md
                 file112.py
+                    ClassName1
+                        method_name1
+                        method_name2
+                        method_name3
             folder12
                 file121.py
+                    ClassName2
+                    ClassName3
         folder2
             file21.py
         """
         # The "nodepath" is the path from root to the node (e.g. huggingface/transformers/examples)
         nodepath_to_node = {}
 
-        for
-
-
-
-
-
-
-
-
-
-
-
-
-
+        for metadata in repo_metadata:
+            path = metadata["file_path"]
+            paths = [path]
+
+            if include_classes or include_methods:
+                # Add the code symbols to the path. For instance, "folder/myfile.py/ClassName/method_name".
+                for class_name, method_name in metadata.get("code_symbols", []):
+                    if include_classes and class_name:
+                        paths.append(path + "/" + class_name)
+                    # We exclude private methods to save tokens.
+                    if include_methods and method_name and not method_name.startswith("_"):
+                        paths.append(
+                            path + "/" + class_name + "/" + method_name if class_name else path + "/" + method_name
+                        )
+
+            for path in paths:
+                items = path.split("/")
+                nodepath = ""
+                parent_node = None
+                for item in items:
+                    nodepath = f"{nodepath}/{item}"
+                    if nodepath in nodepath_to_node:
+                        node = nodepath_to_node[nodepath]
+                    else:
+                        node = Node(item, parent=parent_node)
+                        nodepath_to_node[nodepath] = node
+                    parent_node = node
+
+        root_path = "/" + repo_metadata[0]["file_path"].split("/")[0]
         full_render = ""
         root_node = nodepath_to_node[root_path]
         for pre, fill, node in RenderTree(root_node):
@@ -200,6 +301,24 @@
             return None
 
 
+class RerankerWithErrorHandling(BaseRetriever):
+    """Wraps a `ContextualCompressionRetriever` to catch errors during inference.
+
+    In practice, we see occasional `requests.exceptions.ReadTimeout` from the NVIDIA reranker, which crash the entire
+    pipeline. This wrapper catches such exceptions by simply returning the documents in the original order.
+    """
+
+    def __init__(self, reranker: ContextualCompressionRetriever):
+        self.reranker = reranker
+
+    def _get_relevant_documents(self, query: str, *, run_manager=None) -> List[Document]:
+        try:
+            return self.reranker._get_relevant_documents(query, run_manager=run_manager)
+        except Exception as e:
+            logging.error(f"Error in reranker; preserving original document order from retriever. {e}")
+            return self.reranker.base_retriever._get_relevant_documents(query, run_manager=run_manager)
+
+
 def build_retriever_from_args(args, data_manager: Optional[DataManager] = None):
     """Builds a retriever (with optional reranking) from command-line arguments."""
     if args.llm_retriever: