Spaces:
Running
Running
Add an LLM retriever (#70)
Browse files- README.md +7 -1
- benchmarks/retrieval/retrieve.py +5 -17
- pyproject.toml +2 -0
- sage/chat.py +2 -12
- sage/config.py +42 -1
- sage/data_manager.py +37 -9
- sage/index.py +5 -16
- sage/retriever.py +202 -13
README.md
CHANGED
|
@@ -225,10 +225,16 @@ Currently, we support the following types of retrieval:
|
|
| 225 |
- Note this is not available when running locally, only when using Pinecone as a vector store.
|
| 226 |
- Contrary to [Anthropic's findings](https://www.anthropic.com/news/contextual-retrieval), we find that BM25 is actually damaging performance *on codebases*, because it gives undeserved advantage to Markdown files.
|
| 227 |
|
| 228 |
-
- **Multi-query retrieval** performs multiple query rewrites, makes a separate retrieval call for each, and takes the union of the retrieved documents. You can activate it by passing `--multi-query-retrieval`.
|
| 229 |
|
| 230 |
- We find that [on our benchmark](benchmark/retrieval/README.md) this only marginally improves retrieval quality (from 0.44 to 0.46 R-precision) while being significantly slower and more expensive due to LLM calls. But your mileage may vary.
|
| 231 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
</details>
|
| 233 |
|
| 234 |
# Why chat with a codebase?
|
|
|
|
| 225 |
- Note this is not available when running locally, only when using Pinecone as a vector store.
|
| 226 |
- Contrary to [Anthropic's findings](https://www.anthropic.com/news/contextual-retrieval), we find that BM25 is actually damaging performance *on codebases*, because it gives undeserved advantage to Markdown files.
|
| 227 |
|
| 228 |
+
- **Multi-query retrieval** performs multiple query rewrites, makes a separate retrieval call for each, and takes the union of the retrieved documents. You can activate it by passing `--multi-query-retrieval`. This can be combined with both vanilla and hybrid RAG.
|
| 229 |
|
| 230 |
- We find that [on our benchmark](benchmark/retrieval/README.md) this only marginally improves retrieval quality (from 0.44 to 0.46 R-precision) while being significantly slower and more expensive due to LLM calls. But your mileage may vary.
|
| 231 |
|
| 232 |
+
- **LLM-only retrieval** completely circumvents indexing the codebase. We simply enumerate all file paths and pass them to an LLM together with the user query. We ask the LLM which files are likely to be relevant for the user query, solely based on their filenames. You can activate it by passing `--llm-retriever`.
|
| 233 |
+
|
| 234 |
+
- We find that [on our benchmark](benchmark/retrieval/README.md) the performance is comparable with vector database solutions (R-precision is 0.44 for both). This is quite remarkable, since we've saved so much effort by not indexing the codebase. However, we are reluctant to claim that these findings generalize, for the following reasons:
|
| 235 |
+
- Our (artificial) dataset occasionally contains explicit path names in the query, making it trivial for the LLM. Sample query: *"Alice is managing a series of machine learning experiments. Please explain in detail how `main` in `examples/pytorch/image-pretraining/run_mim.py` allows her to organize the outputs of each experiment in separate directories."*
|
| 236 |
+
- Our benchmark focuses on the Transformers library, which is well-maintained and the file paths are often meaningful. This might not be the case for all codebases.
|
| 237 |
+
|
| 238 |
</details>
|
| 239 |
|
| 240 |
# Why chat with a codebase?
|
benchmarks/retrieval/retrieve.py
CHANGED
|
@@ -41,24 +41,12 @@ def main():
|
|
| 41 |
)
|
| 42 |
|
| 43 |
parser.add("--max-instances", default=None, type=int, help="Maximum number of instances to process.")
|
| 44 |
-
|
| 45 |
-
sage.config.
|
| 46 |
-
sage.config.add_embedding_args(parser)
|
| 47 |
-
sage.config.add_vector_store_args(parser)
|
| 48 |
-
sage.config.add_reranking_args(parser)
|
| 49 |
-
sage.config.add_repo_args(parser)
|
| 50 |
-
sage.config.add_indexing_args(parser)
|
| 51 |
args = parser.parse_args()
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
commit_hash=args.commit_hash,
|
| 56 |
-
access_token=os.getenv("GITHUB_TOKEN"),
|
| 57 |
-
local_dir=args.local_dir,
|
| 58 |
-
inclusion_file=args.include,
|
| 59 |
-
exclusion_file=args.exclude,
|
| 60 |
-
)
|
| 61 |
-
repo_manager.download()
|
| 62 |
retriever = build_retriever_from_args(args, repo_manager)
|
| 63 |
|
| 64 |
with open(args.benchmark, "r") as f:
|
|
|
|
| 41 |
)
|
| 42 |
|
| 43 |
parser.add("--max-instances", default=None, type=int, help="Maximum number of instances to process.")
|
| 44 |
+
|
| 45 |
+
validator = sage.config.add_all_args(parser)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
args = parser.parse_args()
|
| 47 |
+
validator(args)
|
| 48 |
+
|
| 49 |
+
repo_manager = GitHubRepoManager.from_args(args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
retriever = build_retriever_from_args(args, repo_manager)
|
| 51 |
|
| 52 |
with open(args.benchmark, "r") as f:
|
pyproject.toml
CHANGED
|
@@ -23,6 +23,7 @@ classifiers = [
|
|
| 23 |
dependencies = [
|
| 24 |
"GitPython==3.1.43",
|
| 25 |
"Pygments==2.18.0",
|
|
|
|
| 26 |
"cohere==5.9.2",
|
| 27 |
"configargparse",
|
| 28 |
"fastapi==0.112.2",
|
|
@@ -46,6 +47,7 @@ dependencies = [
|
|
| 46 |
"pinecone==5.0.1",
|
| 47 |
"pinecone-text==0.9.0",
|
| 48 |
"python-dotenv==1.0.1",
|
|
|
|
| 49 |
"requests==2.32.3",
|
| 50 |
"semchunk==2.2.0",
|
| 51 |
"sentence-transformers==3.1.0",
|
|
|
|
| 23 |
dependencies = [
|
| 24 |
"GitPython==3.1.43",
|
| 25 |
"Pygments==2.18.0",
|
| 26 |
+
"anytree==2.12.1",
|
| 27 |
"cohere==5.9.2",
|
| 28 |
"configargparse",
|
| 29 |
"fastapi==0.112.2",
|
|
|
|
| 47 |
"pinecone==5.0.1",
|
| 48 |
"pinecone-text==0.9.0",
|
| 49 |
"python-dotenv==1.0.1",
|
| 50 |
+
"python-Levenshtein==0.26.0",
|
| 51 |
"requests==2.32.3",
|
| 52 |
"semchunk==2.2.0",
|
| 53 |
"sentence-transformers==3.1.0",
|
sage/chat.py
CHANGED
|
@@ -71,20 +71,10 @@ def main():
|
|
| 71 |
default=False,
|
| 72 |
help="Whether to make the gradio app publicly accessible.",
|
| 73 |
)
|
| 74 |
-
sage_config.add_config_args(parser)
|
| 75 |
-
|
| 76 |
-
arg_validators = [
|
| 77 |
-
sage_config.add_repo_args(parser),
|
| 78 |
-
sage_config.add_embedding_args(parser),
|
| 79 |
-
sage_config.add_vector_store_args(parser),
|
| 80 |
-
sage_config.add_reranking_args(parser),
|
| 81 |
-
sage_config.add_llm_args(parser),
|
| 82 |
-
]
|
| 83 |
|
|
|
|
| 84 |
args = parser.parse_args()
|
| 85 |
-
|
| 86 |
-
for validator in arg_validators:
|
| 87 |
-
validator(args)
|
| 88 |
|
| 89 |
rag_chain = build_rag_chain(args)
|
| 90 |
|
|
|
|
| 71 |
default=False,
|
| 72 |
help="Whether to make the gradio app publicly accessible.",
|
| 73 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
+
validator = sage_config.add_all_args(parser)
|
| 76 |
args = parser.parse_args()
|
| 77 |
+
validator(args)
|
|
|
|
|
|
|
| 78 |
|
| 79 |
rag_chain = build_rag_chain(args)
|
| 80 |
|
sage/config.py
CHANGED
|
@@ -71,6 +71,7 @@ def add_config_args(parser: ArgumentParser):
|
|
| 71 |
args, _ = parser.parse_known_args()
|
| 72 |
config_file = pkg_resources.resource_filename(__name__, f"configs/{args.mode}.yaml")
|
| 73 |
parser.set_defaults(config=config_file)
|
|
|
|
| 74 |
|
| 75 |
|
| 76 |
def add_repo_args(parser: ArgumentParser) -> Callable:
|
|
@@ -157,6 +158,14 @@ def add_vector_store_args(parser: ArgumentParser) -> Callable:
|
|
| 157 |
help="When set to True, we rewrite the query 5 times, perform retrieval for each rewrite, and take the union "
|
| 158 |
"of retrieved documents. See https://python.langchain.com/v0.1/docs/modules/data_connection/retrievers/MultiQueryRetriever/.",
|
| 159 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
return validate_vector_store_args
|
| 161 |
|
| 162 |
|
|
@@ -221,6 +230,25 @@ def add_llm_args(parser: ArgumentParser) -> Callable:
|
|
| 221 |
return lambda _: True
|
| 222 |
|
| 223 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
def validate_repo_args(args):
|
| 225 |
"""Validates the configuration of the repository."""
|
| 226 |
if not re.match(r"^[^/]+/[^/]+$", args.repo_id):
|
|
@@ -233,7 +261,7 @@ def _validate_openai_embedding_args(args):
|
|
| 233 |
raise ValueError("Please set the OPENAI_API_KEY environment variable.")
|
| 234 |
|
| 235 |
if not args.embedding_model:
|
| 236 |
-
args.embedding_model = "text-embedding-
|
| 237 |
|
| 238 |
if args.embedding_model not in OPENAI_DEFAULT_EMBEDDING_SIZE.keys():
|
| 239 |
raise ValueError(f"Unrecognized embeddings.model={args.embedding_model}")
|
|
@@ -347,6 +375,19 @@ def validate_embedding_args(args):
|
|
| 347 |
|
| 348 |
def validate_vector_store_args(args):
|
| 349 |
"""Validates the configuration of the vector store and sets defaults."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 350 |
|
| 351 |
if not args.index_namespace:
|
| 352 |
# Attempt to derive a default index namespace from the repository information.
|
|
|
|
| 71 |
args, _ = parser.parse_known_args()
|
| 72 |
config_file = pkg_resources.resource_filename(__name__, f"configs/{args.mode}.yaml")
|
| 73 |
parser.set_defaults(config=config_file)
|
| 74 |
+
return lambda _: True
|
| 75 |
|
| 76 |
|
| 77 |
def add_repo_args(parser: ArgumentParser) -> Callable:
|
|
|
|
| 158 |
help="When set to True, we rewrite the query 5 times, perform retrieval for each rewrite, and take the union "
|
| 159 |
"of retrieved documents. See https://python.langchain.com/v0.1/docs/modules/data_connection/retrievers/MultiQueryRetriever/.",
|
| 160 |
)
|
| 161 |
+
parser.add(
|
| 162 |
+
"--llm-retriever",
|
| 163 |
+
action=argparse.BooleanOptionalAction,
|
| 164 |
+
default=False,
|
| 165 |
+
help="When set to True, we use an LLM for retrieval: we pass the repository file hierarchy together with the "
|
| 166 |
+
"user query and ask the LLM to choose relevant files solely based on their paths. No indexing will be done, so "
|
| 167 |
+
"all the vector store / embedding arguments will be ignored.",
|
| 168 |
+
)
|
| 169 |
return validate_vector_store_args
|
| 170 |
|
| 171 |
|
|
|
|
| 230 |
return lambda _: True
|
| 231 |
|
| 232 |
|
| 233 |
+
def add_all_args(parser: ArgumentParser) -> Callable:
|
| 234 |
+
"""Adds all arguments to the parser and returns a validator."""
|
| 235 |
+
arg_validators = [
|
| 236 |
+
add_config_args(parser),
|
| 237 |
+
add_repo_args(parser),
|
| 238 |
+
add_embedding_args(parser),
|
| 239 |
+
add_vector_store_args(parser),
|
| 240 |
+
add_reranking_args(parser),
|
| 241 |
+
add_indexing_args(parser),
|
| 242 |
+
add_llm_args(parser),
|
| 243 |
+
]
|
| 244 |
+
|
| 245 |
+
def validate_all(args):
|
| 246 |
+
for validator in arg_validators:
|
| 247 |
+
validator(args)
|
| 248 |
+
|
| 249 |
+
return validate_all
|
| 250 |
+
|
| 251 |
+
|
| 252 |
def validate_repo_args(args):
|
| 253 |
"""Validates the configuration of the repository."""
|
| 254 |
if not re.match(r"^[^/]+/[^/]+$", args.repo_id):
|
|
|
|
| 261 |
raise ValueError("Please set the OPENAI_API_KEY environment variable.")
|
| 262 |
|
| 263 |
if not args.embedding_model:
|
| 264 |
+
args.embedding_model = "text-embedding-3-small"
|
| 265 |
|
| 266 |
if args.embedding_model not in OPENAI_DEFAULT_EMBEDDING_SIZE.keys():
|
| 267 |
raise ValueError(f"Unrecognized embeddings.model={args.embedding_model}")
|
|
|
|
| 375 |
|
| 376 |
def validate_vector_store_args(args):
|
| 377 |
"""Validates the configuration of the vector store and sets defaults."""
|
| 378 |
+
if args.llm_retriever:
|
| 379 |
+
if not os.getenv("ANTHROPIC_API_KEY"):
|
| 380 |
+
raise ValueError(
|
| 381 |
+
"Please set the ANTHROPIC_API_KEY environment variable to use the LLM retriever. "
|
| 382 |
+
"(We're constrained to Claude because we need prompt caching.)"
|
| 383 |
+
)
|
| 384 |
+
|
| 385 |
+
if args.index_issues:
|
| 386 |
+
# The LLM retriever only makes sense on the code repository, since it passes file paths to the LLM.
|
| 387 |
+
raise ValueError("Cannot use --index-issues with --llm-retriever.")
|
| 388 |
+
|
| 389 |
+
# When using an LLM retriever, all the vector store arguments are ignored.
|
| 390 |
+
return
|
| 391 |
|
| 392 |
if not args.index_namespace:
|
| 393 |
# Attempt to derive a default index namespace from the repository information.
|
sage/data_manager.py
CHANGED
|
@@ -175,15 +175,14 @@ class GitHubRepoManager(DataManager):
|
|
| 175 |
)
|
| 176 |
return True
|
| 177 |
|
| 178 |
-
def walk(self) -> Generator[Tuple[Any, Dict], None, None]:
|
| 179 |
"""Walks the local repository path and yields a tuple of (content, metadata) for each file.
|
| 180 |
The filepath is relative to the root of the repository (e.g. "org/repo/your/file/path.py").
|
| 181 |
|
| 182 |
Args:
|
| 183 |
-
|
| 184 |
-
excluded_extensions: Optional set of extensions to exclude.
|
| 185 |
"""
|
| 186 |
-
# We will keep
|
| 187 |
repo_name = self.repo_id.replace("/", "_")
|
| 188 |
included_log_file = os.path.join(self.log_dir, f"included_{repo_name}.txt")
|
| 189 |
excluded_log_file = os.path.join(self.log_dir, f"excluded_{repo_name}.txt")
|
|
@@ -208,20 +207,49 @@ class GitHubRepoManager(DataManager):
|
|
| 208 |
f.write(path + "\n")
|
| 209 |
|
| 210 |
for file_path in included_file_paths:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
with open(file_path, "r") as f:
|
| 212 |
try:
|
| 213 |
contents = f.read()
|
| 214 |
except UnicodeDecodeError:
|
| 215 |
logging.warning("Unable to decode file %s. Skipping.", file_path)
|
| 216 |
continue
|
| 217 |
-
relative_file_path = file_path[len(self.local_dir) + 1 :]
|
| 218 |
-
metadata = {
|
| 219 |
-
"file_path": relative_file_path,
|
| 220 |
-
"url": self.url_for_file(relative_file_path),
|
| 221 |
-
}
|
| 222 |
yield contents, metadata
|
| 223 |
|
| 224 |
def url_for_file(self, file_path: str) -> str:
|
| 225 |
"""Converts a repository file path to a GitHub link."""
|
| 226 |
file_path = file_path[len(self.repo_id) + 1 :]
|
| 227 |
return f"https://github.com/{self.repo_id}/blob/{self.default_branch}/{file_path}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
)
|
| 176 |
return True
|
| 177 |
|
| 178 |
+
def walk(self, get_content: bool = True) -> Generator[Tuple[Any, Dict], None, None]:
|
| 179 |
"""Walks the local repository path and yields a tuple of (content, metadata) for each file.
|
| 180 |
The filepath is relative to the root of the repository (e.g. "org/repo/your/file/path.py").
|
| 181 |
|
| 182 |
Args:
|
| 183 |
+
get_content: When set to True, yields (content, metadata) tuples. When set to False, yields metadata only.
|
|
|
|
| 184 |
"""
|
| 185 |
+
# We will keep appending to these files during the iteration, so we need to clear them first.
|
| 186 |
repo_name = self.repo_id.replace("/", "_")
|
| 187 |
included_log_file = os.path.join(self.log_dir, f"included_{repo_name}.txt")
|
| 188 |
excluded_log_file = os.path.join(self.log_dir, f"excluded_{repo_name}.txt")
|
|
|
|
| 207 |
f.write(path + "\n")
|
| 208 |
|
| 209 |
for file_path in included_file_paths:
|
| 210 |
+
relative_file_path = file_path[len(self.local_dir) + 1 :]
|
| 211 |
+
metadata = {
|
| 212 |
+
"file_path": relative_file_path,
|
| 213 |
+
"url": self.url_for_file(relative_file_path),
|
| 214 |
+
}
|
| 215 |
+
|
| 216 |
+
if not get_content:
|
| 217 |
+
yield metadata
|
| 218 |
+
continue
|
| 219 |
+
|
| 220 |
with open(file_path, "r") as f:
|
| 221 |
try:
|
| 222 |
contents = f.read()
|
| 223 |
except UnicodeDecodeError:
|
| 224 |
logging.warning("Unable to decode file %s. Skipping.", file_path)
|
| 225 |
continue
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
yield contents, metadata
|
| 227 |
|
| 228 |
def url_for_file(self, file_path: str) -> str:
|
| 229 |
"""Converts a repository file path to a GitHub link."""
|
| 230 |
file_path = file_path[len(self.repo_id) + 1 :]
|
| 231 |
return f"https://github.com/{self.repo_id}/blob/{self.default_branch}/{file_path}"
|
| 232 |
+
|
| 233 |
+
def read_file(self, relative_file_path: str) -> str:
|
| 234 |
+
"""Reads the content of the file at the given path."""
|
| 235 |
+
file_path = os.path.join(self.local_dir, relative_file_path)
|
| 236 |
+
with open(file_path, "r") as f:
|
| 237 |
+
return f.read()
|
| 238 |
+
|
| 239 |
+
def from_args(args: Dict):
|
| 240 |
+
"""Creates a GitHubRepoManager from command-line arguments and clones the underlying repository."""
|
| 241 |
+
repo_manager = GitHubRepoManager(
|
| 242 |
+
repo_id=args.repo_id,
|
| 243 |
+
commit_hash=args.commit_hash,
|
| 244 |
+
access_token=os.getenv("GITHUB_TOKEN"),
|
| 245 |
+
local_dir=args.local_dir,
|
| 246 |
+
inclusion_file=args.include,
|
| 247 |
+
exclusion_file=args.exclude,
|
| 248 |
+
)
|
| 249 |
+
success = repo_manager.download()
|
| 250 |
+
if not success:
|
| 251 |
+
raise ValueError(
|
| 252 |
+
f"Unable to clone {args.repo_id}. Please check that it exists and you have access to it. "
|
| 253 |
+
"For private repositories, please set the GITHUB_TOKEN variable in your environment."
|
| 254 |
+
)
|
| 255 |
+
return repo_manager
|
sage/index.py
CHANGED
|
@@ -36,6 +36,10 @@ def main():
|
|
| 36 |
for validator in arg_validators:
|
| 37 |
validator(args)
|
| 38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
# Additionally validate embedder and vector store compatibility.
|
| 40 |
if args.embedding_provider == "openai" and args.vector_store_provider != "pinecone":
|
| 41 |
parser.error("When using OpenAI embedder, the vector store type must be Pinecone.")
|
|
@@ -50,22 +54,7 @@ def main():
|
|
| 50 |
repo_embedder = None
|
| 51 |
if args.index_repo:
|
| 52 |
logging.info("Cloning the repository...")
|
| 53 |
-
repo_manager = GitHubRepoManager(
|
| 54 |
-
args.repo_id,
|
| 55 |
-
commit_hash=args.commit_hash,
|
| 56 |
-
access_token=os.getenv("GITHUB_TOKEN"),
|
| 57 |
-
local_dir=args.local_dir,
|
| 58 |
-
inclusion_file=args.include,
|
| 59 |
-
exclusion_file=args.exclude,
|
| 60 |
-
)
|
| 61 |
-
|
| 62 |
-
success = repo_manager.download()
|
| 63 |
-
if not success:
|
| 64 |
-
raise ValueError(
|
| 65 |
-
f"Unable to clone {args.repo_id}. Please check that it exists and you have access to it. "
|
| 66 |
-
"For private repositories, please set the GITHUB_TOKEN variable in your environment."
|
| 67 |
-
)
|
| 68 |
-
|
| 69 |
logging.info("Embedding the repo...")
|
| 70 |
chunker = UniversalFileChunker(max_tokens=args.tokens_per_chunk)
|
| 71 |
repo_embedder = build_batch_embedder_from_flags(repo_manager, chunker, args)
|
|
|
|
| 36 |
for validator in arg_validators:
|
| 37 |
validator(args)
|
| 38 |
|
| 39 |
+
if args.llm_retriever:
|
| 40 |
+
logging.warning("The LLM retriever does not require indexing, so this script is a no-op.")
|
| 41 |
+
return
|
| 42 |
+
|
| 43 |
# Additionally validate embedder and vector store compatibility.
|
| 44 |
if args.embedding_provider == "openai" and args.vector_store_provider != "pinecone":
|
| 45 |
parser.error("When using OpenAI embedder, the vector store type must be Pinecone.")
|
|
|
|
| 54 |
repo_embedder = None
|
| 55 |
if args.index_repo:
|
| 56 |
logging.info("Cloning the repository...")
|
| 57 |
+
repo_manager = GitHubRepoManager.from_args(args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
logging.info("Embedding the repo...")
|
| 59 |
chunker = UniversalFileChunker(max_tokens=args.tokens_per_chunk)
|
| 60 |
repo_embedder = build_batch_embedder_from_flags(repo_manager, chunker, args)
|
sage/retriever.py
CHANGED
|
@@ -1,32 +1,221 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
| 2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
from langchain.retrievers import ContextualCompressionRetriever
|
| 4 |
from langchain.retrievers.multi_query import MultiQueryRetriever
|
|
|
|
|
|
|
| 5 |
from langchain_google_genai import GoogleGenerativeAIEmbeddings
|
| 6 |
from langchain_openai import OpenAIEmbeddings
|
| 7 |
from langchain_voyageai import VoyageAIEmbeddings
|
|
|
|
| 8 |
|
| 9 |
-
from sage.data_manager import DataManager
|
| 10 |
from sage.llm import build_llm_via_langchain
|
| 11 |
from sage.reranker import build_reranker
|
| 12 |
from sage.vector_store import build_vector_store_from_args
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
def build_retriever_from_args(args, data_manager: Optional[DataManager] = None):
|
| 16 |
"""Builds a retriever (with optional reranking) from command-line arguments."""
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
embeddings = OpenAIEmbeddings(model=args.embedding_model)
|
| 20 |
-
elif args.embedding_provider == "voyage":
|
| 21 |
-
embeddings = VoyageAIEmbeddings(model=args.embedding_model)
|
| 22 |
-
elif args.embedding_provider == "gemini":
|
| 23 |
-
embeddings = GoogleGenerativeAIEmbeddings(model=args.embedding_model)
|
| 24 |
else:
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
|
| 31 |
if args.multi_query_retriever:
|
| 32 |
retriever = MultiQueryRetriever.from_llm(
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
from typing import List, Optional
|
| 4 |
|
| 5 |
+
import anthropic
|
| 6 |
+
import Levenshtein
|
| 7 |
+
from anytree import Node, RenderTree
|
| 8 |
+
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
|
| 9 |
from langchain.retrievers import ContextualCompressionRetriever
|
| 10 |
from langchain.retrievers.multi_query import MultiQueryRetriever
|
| 11 |
+
from langchain.schema import BaseRetriever, Document
|
| 12 |
+
from langchain_core.output_parsers import CommaSeparatedListOutputParser
|
| 13 |
from langchain_google_genai import GoogleGenerativeAIEmbeddings
|
| 14 |
from langchain_openai import OpenAIEmbeddings
|
| 15 |
from langchain_voyageai import VoyageAIEmbeddings
|
| 16 |
+
from pydantic import Field
|
| 17 |
|
| 18 |
+
from sage.data_manager import DataManager, GitHubRepoManager
|
| 19 |
from sage.llm import build_llm_via_langchain
|
| 20 |
from sage.reranker import build_reranker
|
| 21 |
from sage.vector_store import build_vector_store_from_args
|
| 22 |
|
| 23 |
+
logging.basicConfig(level=logging.INFO)
|
| 24 |
+
logger = logging.getLogger()
|
| 25 |
+
logger.setLevel(logging.INFO)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class LLMRetriever(BaseRetriever):
|
| 29 |
+
"""Custom Langchain retriever based on an LLM.
|
| 30 |
+
|
| 31 |
+
Builds a representation of the folder structure of the repo, feeds it to an LLM, and asks the LLM for the most
|
| 32 |
+
relevant files for a particular user query, expecting it to make decisions based solely on file names.
|
| 33 |
+
|
| 34 |
+
Only works with Claude/Anthropic, because it's very slow (e.g. 15s for a mid-sized codebase) and we need prompt
|
| 35 |
+
caching to make it usable.
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
repo_manager: GitHubRepoManager = Field(...)
|
| 39 |
+
top_k: int = Field(...)
|
| 40 |
+
all_repo_files: List[str] = Field(...)
|
| 41 |
+
repo_hierarchy: str = Field(...)
|
| 42 |
+
|
| 43 |
+
def __init__(self, repo_manager: GitHubRepoManager, top_k: int):
|
| 44 |
+
super().__init__()
|
| 45 |
+
self.repo_manager = repo_manager
|
| 46 |
+
self.top_k = top_k
|
| 47 |
+
|
| 48 |
+
# Best practice would be to make these fields @cached_property, but that impedes class serialization.
|
| 49 |
+
self.all_repo_files = [metadata["file_path"] for metadata in self.repo_manager.walk(get_content=False)]
|
| 50 |
+
self.repo_hierarchy = LLMRetriever._render_file_hierarchy(self.all_repo_files)
|
| 51 |
+
|
| 52 |
+
if not os.environ.get("ANTHROPIC_API_KEY"):
|
| 53 |
+
raise ValueError("Please set the ANTHROPIC_API_KEY environment variable for the LLMRetriever.")
|
| 54 |
+
|
| 55 |
+
def _get_relevant_documents(self, query: str, *, run_manager: CallbackManagerForRetrieverRun) -> List[Document]:
|
| 56 |
+
"""Retrieve relevant documents for a given query."""
|
| 57 |
+
filenames = self._ask_llm_to_retrieve(user_query=query, top_k=self.top_k)
|
| 58 |
+
documents = []
|
| 59 |
+
for filename in filenames:
|
| 60 |
+
document = Document(
|
| 61 |
+
page_content=self.repo_manager.read_file(filename),
|
| 62 |
+
metadata={"file_path": filename, "url": self.repo_manager.url_for_file(filename)},
|
| 63 |
+
)
|
| 64 |
+
documents.append(document)
|
| 65 |
+
return documents
|
| 66 |
+
|
| 67 |
+
def _ask_llm_to_retrieve(self, user_query: str, top_k: int) -> List[str]:
|
| 68 |
+
"""Feeds the file hierarchy and user query to the LLM and asks which files might be relevant."""
|
| 69 |
+
sys_prompt = f"""
|
| 70 |
+
You are a retriever system. You will be given a user query and a list of files in a GitHub repository. Your task is to determine the top {top_k} files that are most relevant to the user query.
|
| 71 |
+
DO NOT RESPOND TO THE USER QUERY DIRECTLY. Instead, respond with full paths to relevant files that could contain the answer to the query. Say absolutely nothing else other than the file paths.
|
| 72 |
+
|
| 73 |
+
Here is the file hierarchy of the GitHub repository:
|
| 74 |
+
|
| 75 |
+
{self.repo_hierarchy}
|
| 76 |
+
"""
|
| 77 |
+
# We are deliberately repeating the "DO NOT RESPOND TO THE USER QUERY DIRECTLY" instruction here.
|
| 78 |
+
augmented_user_query = f"""
|
| 79 |
+
User query: {user_query}
|
| 80 |
+
|
| 81 |
+
DO NOT RESPOND TO THE USER QUERY DIRECTLY. Instead, respond with full paths to relevant files that could contain the answer to the query. Say absolutely nothing else other than the file paths.
|
| 82 |
+
"""
|
| 83 |
+
response = LLMRetriever._call_via_anthropic_with_prompt_caching(sys_prompt, augmented_user_query)
|
| 84 |
+
files_from_llm = response.content[0].text.strip().split("\n")
|
| 85 |
+
validated_files = []
|
| 86 |
+
|
| 87 |
+
for filename in files_from_llm:
|
| 88 |
+
if filename not in self.all_repo_files:
|
| 89 |
+
if "/" not in filename:
|
| 90 |
+
# This is most likely some natural language excuse from the LLM; skip it.
|
| 91 |
+
continue
|
| 92 |
+
# Try a few heuristics to fix the filename.
|
| 93 |
+
filename = LLMRetriever._fix_filename(filename, self.repo_manager.repo_id)
|
| 94 |
+
if filename not in self.all_repo_files:
|
| 95 |
+
# The heuristics failed; try to find the closest filename in the repo.
|
| 96 |
+
filename = LLMRetriever._find_closest_filename(filename, self.all_repo_files)
|
| 97 |
+
if filename in self.all_repo_files:
|
| 98 |
+
validated_files.append(filename)
|
| 99 |
+
return validated_files
|
| 100 |
+
|
| 101 |
+
@staticmethod
|
| 102 |
+
def _call_via_anthropic_with_prompt_caching(system_prompt: str, user_prompt: str) -> str:
|
| 103 |
+
"""Calls the Anthropic API with prompt caching for the system prompt.
|
| 104 |
+
|
| 105 |
+
See https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching.
|
| 106 |
+
|
| 107 |
+
We're circumventing LangChain for now, because the feature is < 1 week old at the time of writing and has no
|
| 108 |
+
documentation: https://github.com/langchain-ai/langchain/pull/27087
|
| 109 |
+
"""
|
| 110 |
+
CLAUDE_MODEL = "claude-3-5-sonnet-20240620"
|
| 111 |
+
client = anthropic.Anthropic()
|
| 112 |
+
|
| 113 |
+
system_message = {"type": "text", "text": system_prompt, "cache_control": {"type": "ephemeral"}}
|
| 114 |
+
user_message = {"role": "user", "content": user_prompt}
|
| 115 |
+
|
| 116 |
+
response = client.beta.prompt_caching.messages.create(
|
| 117 |
+
model=CLAUDE_MODEL,
|
| 118 |
+
max_tokens=1024, # The maximum number of *output* tokens to generate.
|
| 119 |
+
system=[system_message],
|
| 120 |
+
messages=[user_message],
|
| 121 |
+
)
|
| 122 |
+
# Caching information will be under `cache_creation_input_tokens` and `cache_read_input_tokens`.
|
| 123 |
+
# Note that, for prompts shorter than 1024 tokens, Anthropic will not do any caching.
|
| 124 |
+
logging.info("Anthropic prompt caching info: %s", response.usage)
|
| 125 |
+
return response
|
| 126 |
+
|
| 127 |
+
@staticmethod
|
| 128 |
+
def _render_file_hierarchy(file_paths: List[str]) -> str:
|
| 129 |
+
"""Given a list of files, produces a visualization of the file hierarchy. For instance:
|
| 130 |
+
folder1
|
| 131 |
+
folder11
|
| 132 |
+
file111.py
|
| 133 |
+
file112.py
|
| 134 |
+
folder12
|
| 135 |
+
file121.py
|
| 136 |
+
folder2
|
| 137 |
+
file21.py
|
| 138 |
+
"""
|
| 139 |
+
# The "nodepath" is the path from root to the node (e.g. huggingface/transformers/examples)
|
| 140 |
+
nodepath_to_node = {}
|
| 141 |
+
|
| 142 |
+
for path in file_paths:
|
| 143 |
+
items = path.split("/")
|
| 144 |
+
nodepath = ""
|
| 145 |
+
parent_node = None
|
| 146 |
+
for item in items:
|
| 147 |
+
nodepath = f"{nodepath}/{item}"
|
| 148 |
+
if nodepath in nodepath_to_node:
|
| 149 |
+
node = nodepath_to_node[nodepath]
|
| 150 |
+
else:
|
| 151 |
+
node = Node(item, parent=parent_node)
|
| 152 |
+
nodepath_to_node[nodepath] = node
|
| 153 |
+
parent_node = node
|
| 154 |
+
|
| 155 |
+
root_path = f"/{file_paths[0].split('/')[0]}"
|
| 156 |
+
full_render = ""
|
| 157 |
+
root_node = nodepath_to_node[root_path]
|
| 158 |
+
for pre, fill, node in RenderTree(root_node):
|
| 159 |
+
render = "%s%s\n" % (pre, node.name)
|
| 160 |
+
# Replace special lines with empty strings to save on tokens.
|
| 161 |
+
render = render.replace("└", " ").replace("├", " ").replace("│", " ").replace("─", " ")
|
| 162 |
+
full_render += render
|
| 163 |
+
return full_render
|
| 164 |
+
|
| 165 |
+
@staticmethod
|
| 166 |
+
def _fix_filename(filename: str, repo_id: str) -> str:
|
| 167 |
+
"""Attempts to "fix" a filename output by the LLM.
|
| 168 |
+
|
| 169 |
+
Common issues with LLM-generated filenames:
|
| 170 |
+
- The LLM prepends an extraneous "/".
|
| 171 |
+
- The LLM omits the name of the org (e.g. "transformers/README.md" instead of "huggingface/transformers/README.md").
|
| 172 |
+
- The LLM omits the name of the repo (e.g. "huggingface/README.md" instead of "huggingface/transformers/README.md").
|
| 173 |
+
- The LLM omits the org/repo prefix (e.g. "README.md" instead of "huggingface/transformers/README.md").
|
| 174 |
+
"""
|
| 175 |
+
if filename.startswith("/"):
|
| 176 |
+
filename = filename[1:]
|
| 177 |
+
org_name, repo_name = repo_id.split("/")
|
| 178 |
+
items = filename.split("/")
|
| 179 |
+
if filename.startswith(org_name) and not filename.startswith(repo_id):
|
| 180 |
+
new_items = [org_name, repo_name] + items[1:]
|
| 181 |
+
return "/".join(new_items)
|
| 182 |
+
if not filename.startswith(org_name) and filename.startswith(repo_name):
|
| 183 |
+
return f"{org_name}/{filename}"
|
| 184 |
+
if not filename.startswith(org_name) and not filename.startswith(repo_name):
|
| 185 |
+
return f"{org_name}/{repo_name}/{filename}"
|
| 186 |
+
return filename
|
| 187 |
+
|
| 188 |
+
@staticmethod
|
| 189 |
+
def _find_closest_filename(filename: str, repo_filenames: List[str], max_edit_distance: int = 10) -> Optional[str]:
|
| 190 |
+
"""Returns the path in the repo with smallest edit distance from `filename`. Helpful when the `filename` was
|
| 191 |
+
generated by an LLM and parts of it might have been hallucinated. Returns None if the closest path is more than
|
| 192 |
+
`max_edit_distance` away. In case of a tie, returns an arbitrary closest path.
|
| 193 |
+
"""
|
| 194 |
+
distances = [(path, Levenshtein.distance(filename, path)) for path in repo_filenames]
|
| 195 |
+
distances.sort(key=lambda x: x[1])
|
| 196 |
+
if distances[0][1] <= max_edit_distance:
|
| 197 |
+
closest_path = distances[0][0]
|
| 198 |
+
return closest_path
|
| 199 |
+
return None
|
| 200 |
+
|
| 201 |
|
| 202 |
def build_retriever_from_args(args, data_manager: Optional[DataManager] = None):
|
| 203 |
"""Builds a retriever (with optional reranking) from command-line arguments."""
|
| 204 |
+
if args.llm_retriever:
|
| 205 |
+
retriever = LLMRetriever(GitHubRepoManager.from_args(args), top_k=args.retriever_top_k)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
else:
|
| 207 |
+
if args.embedding_provider == "openai":
|
| 208 |
+
embeddings = OpenAIEmbeddings(model=args.embedding_model)
|
| 209 |
+
elif args.embedding_provider == "voyage":
|
| 210 |
+
embeddings = VoyageAIEmbeddings(model=args.embedding_model)
|
| 211 |
+
elif args.embedding_provider == "gemini":
|
| 212 |
+
embeddings = GoogleGenerativeAIEmbeddings(model=args.embedding_model)
|
| 213 |
+
else:
|
| 214 |
+
embeddings = None
|
| 215 |
|
| 216 |
+
retriever = build_vector_store_from_args(args, data_manager).as_retriever(
|
| 217 |
+
top_k=args.retriever_top_k, embeddings=embeddings, namespace=args.index_namespace
|
| 218 |
+
)
|
| 219 |
|
| 220 |
if args.multi_query_retriever:
|
| 221 |
retriever = MultiQueryRetriever.from_llm(
|