` pairs and select the top K scoring ones. The second stage is called *reranking*.
-
-
-
-While the [MTEB leaderboard](https://huggingface.co/spaces/mteb/leaderboard) compares *open-source* embedding models based on their ability to rerank documents, we conducted experiments on the most popular *proprietary* APIs for reranking, including [NVIDIA](https://docs.nvidia.com/nim/nemo-retriever/text-reranking/latest/using-reranking.html), [Voyage](https://docs.voyageai.com/docs/reranker), [Cohere](https://cohere.com/rerank) and [Jina](https://jina.ai/reranker/).
-
-#### Experiment settings
-- File chunks of <= 800 tokens;
-- Dense retriever using OpenAI's `text-embedding-3-small` model;
-- Retrieved `top_k=25` documents;
-- Reranked documents and selected `top_k=3`.
-
-#### Results
-- Across all evaluation metrics, the highest performing rerankers are, in this order: [NVIDIA](https://docs.nvidia.com/nim/nemo-retriever/text-reranking/latest/using-reranking.html), [Voyage](https://docs.voyageai.com/docs/reranker), [Cohere](https://cohere.com/rerank) and [Jina](https://jina.ai/reranker/).
-- Not using a reranker at all completely tanks the performance.
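The two-stage retrieve-then-rerank flow can be sketched in pure Python. `dense_retrieve` and `rerank` below are hypothetical stand-ins for the real APIs (embedding search with `text-embedding-3-small`, then a reranking endpoint); here both are toy lexical scorers, kept only to show the shape of the pipeline.

```python
def dense_retrieve(query, corpus, top_k=25):
    """Stand-in for the first stage: score documents by word overlap with the query."""
    q_tokens = set(query.lower().split())
    scored = [(doc, len(q_tokens & set(doc.lower().split()))) for doc in corpus]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return [doc for doc, _ in scored[:top_k]]

def rerank(query, docs, top_k=3):
    """Stand-in for the second stage: re-score the candidates and keep the best few."""
    scored = [(doc, sum(doc.lower().count(t) for t in query.lower().split())) for doc in docs]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return [doc for doc, _ in scored[:top_k]]

corpus = [
    "rerankers improve retrieval",
    "chunking code files",
    "dense embeddings for code retrieval",
]
candidates = dense_retrieve("code retrieval", corpus, top_k=25)
top_docs = rerank("code retrieval", candidates, top_k=3)
```

In the real setup, only the final `top_k=3` documents are passed to the LLM, which is why the reranker's quality matters so much.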
-
-## Retrieval: Sparse vs Dense
-:classical_building: **Verdict**: Use fully dense embeddings.
-
-So far, we've been experimenting with purely *dense* retrieval. That is, documents are selected solely on the cosine distance between their embedding and the query embedding.
-
-Before the emergence of deep learning, retrievers used to be *sparse*. Such retrievers (e.g. [TF-IDF](https://en.wikipedia.org/wiki/Tf%E2%80%93idf) or [BM25](https://en.wikipedia.org/wiki/Okapi_BM25)) were based on vectors of word counts (the vector of a document has the length of the dictionary, with each entry showing how many times a token occurs in the document; the term *sparse* comes from the fact that most entries are 0).
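For concreteness, here is a from-scratch sketch of BM25 scoring (not the exact implementation used in our experiments, which rely on an off-the-shelf retriever). Documents and queries are token lists; the constants `k1` and `b` are the standard BM25 defaults.

```python
import math
from collections import Counter

def bm25_scores(query_tokens, corpus, k1=1.5, b=0.75):
    """Score each document (a list of tokens) against the query with BM25."""
    n_docs = len(corpus)
    avg_len = sum(len(doc) for doc in corpus) / n_docs
    # Document frequency: in how many documents each term appears.
    df = Counter(term for doc in corpus for term in set(doc))
    scores = []
    for doc in corpus:
        tf = Counter(doc)
        score = 0.0
        for term in query_tokens:
            if term not in tf:
                continue
            idf = math.log((n_docs - df[term] + 0.5) / (df[term] + 0.5) + 1)
            norm = tf[term] * (k1 + 1) / (tf[term] + k1 * (1 - b + b * len(doc) / avg_len))
            score += idf * norm
        scores.append(score)
    return scores

# A query containing a rare token (e.g. a class name) strongly favors
# the one document that actually contains it.
corpus = [["sparse", "retrieval"], ["dense", "retrieval"], ["unique", "ClassName"]]
print(bm25_scores(["ClassName"], corpus))
```

Note that a term contributes nothing unless it matches a document token exactly, which is both the appeal and the weakness of sparse retrieval.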
-
-Since sparse retrievers rely on exact string match, one might assume they come in handy when the query contains a relatively unique token (e.g. a class name) that occurs in a small number of documents.
-
-At the intersection of dense and sparse retrievers, *hybrid* retrievers score documents by the weighted average of the dense and sparse scores.
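A minimal sketch of that weighted average, assuming both score lists are min-max normalized first so they live on the same scale (`alpha` is a hypothetical name for the mixing weight; `alpha=1` reduces to purely dense, `alpha=0` to purely sparse):

```python
def hybrid_scores(dense, sparse, alpha=0.5):
    """Weighted average of min-max normalized dense and sparse scores."""
    def norm(xs):
        lo, hi = min(xs), max(xs)
        return [(x - lo) / (hi - lo) if hi > lo else 0.0 for x in xs]
    return [alpha * d + (1 - alpha) * s for d, s in zip(norm(dense), norm(sparse))]
```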
-
-
-
-In the experiment above, we compared the three types of retrievers (dense, hybrid and sparse).
-
-#### Experiment settings
-- File chunks of <= 800 tokens;
-- For the dense and hybrid retrievers, we used OpenAI's `text-embedding-3-small` model for embeddings;
-- Retrieved `top_k=25` documents;
-- Reranked documents using the [NVIDIA re-ranker](https://docs.nvidia.com/nim/nemo-retriever/text-reranking/latest/using-reranking.html) and selected `top_k=3`.
-
-#### Results
-Somewhat surprisingly, sparse retrieval actively hurts performance. The reason is that exact string matching favors files written in natural language (which therefore match the token distribution of the query).
-
-The plot below shows what percentage of the retrieved files are in Markdown. The purely sparse retriever chooses a Markdown file 40% of the time! Remember that we designed our questions so that the required context consists of Python files. This doesn't preclude Markdown files from actually being helpful in answering some of the questions, but surely not to this degree.
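The measurement behind that plot is straightforward; a sketch of the check (the function name is ours, not from the benchmark code):

```python
def markdown_fraction(retrieved_paths):
    """Fraction of retrieved file paths that point to Markdown files."""
    md = sum(1 for p in retrieved_paths if p.lower().endswith((".md", ".markdown")))
    return md / len(retrieved_paths) if retrieved_paths else 0.0
```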
-
-
-
-## Chunk sizes
-:classical_building: **Verdict**: 800 tokens per chunk works well.
-
-The [CodeRag paper](https://arxiv.org/pdf/2406.14497) suggests that the ideal chunk size is somewhere between 200-800 tokens. All our experiments above used 800 tokens per chunk. When experimenting with the other end of the spectrum, we saw very mild improvements from having smaller chunks. We believe that these marginal gains are not worth the increased indexing time (since we need to send 4x more queries to the batch embedding APIs).
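Fixed-size chunking can be sketched as below. This is a naive token-window splitter; a production chunker for code would typically also respect syntactic boundaries (functions, classes) rather than cutting mid-construct.

```python
def chunk_tokens(tokens, max_tokens=800):
    """Split a token sequence into consecutive chunks of at most `max_tokens`."""
    return [tokens[i:i + max_tokens] for i in range(0, len(tokens), max_tokens)]
```

Halving `max_tokens` doubles the number of chunks, and with it the number of calls to the batch embedding API, which is the indexing cost trade-off discussed above.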
-
-
diff --git a/benchmarks/retrieval/assets/chunks.png b/benchmarks/retrieval/assets/chunks.png
deleted file mode 100644
index 69ee271538912d70c5a603252a26aadc64823d6c..0000000000000000000000000000000000000000
Binary files a/benchmarks/retrieval/assets/chunks.png and /dev/null differ
diff --git a/benchmarks/retrieval/assets/embeddings.png b/benchmarks/retrieval/assets/embeddings.png
deleted file mode 100644
index 6cf1fa8d296e8c2c859b7ab2de738a663159ce66..0000000000000000000000000000000000000000
Binary files a/benchmarks/retrieval/assets/embeddings.png and /dev/null differ
diff --git a/benchmarks/retrieval/assets/markdown.png b/benchmarks/retrieval/assets/markdown.png
deleted file mode 100644
index 5d93ea2b2264c31f810b85e2ab73952590c43602..0000000000000000000000000000000000000000
Binary files a/benchmarks/retrieval/assets/markdown.png and /dev/null differ
diff --git a/benchmarks/retrieval/assets/rerankers.png b/benchmarks/retrieval/assets/rerankers.png
deleted file mode 100644
index e6db10c680a059b8982257114a6d2dc707f080c4..0000000000000000000000000000000000000000
Binary files a/benchmarks/retrieval/assets/rerankers.png and /dev/null differ
diff --git a/benchmarks/retrieval/assets/retrievers.png b/benchmarks/retrieval/assets/retrievers.png
deleted file mode 100644
index b693562dd1ab2410365daf509050fcf87a768b54..0000000000000000000000000000000000000000
Binary files a/benchmarks/retrieval/assets/retrievers.png and /dev/null differ
diff --git a/benchmarks/retrieval/requirements.txt b/benchmarks/retrieval/requirements.txt
deleted file mode 100644
index a6b70c870fd6d16a55039753042da2b38fff0ce7..0000000000000000000000000000000000000000
--- a/benchmarks/retrieval/requirements.txt
+++ /dev/null
@@ -1,2 +0,0 @@
-python-dotenv
-ir_measures
diff --git a/benchmarks/retrieval/retrieve.py b/benchmarks/retrieval/retrieve.py
deleted file mode 100644
index cca2fbf2ae214ae048e92b36f38a497c6f116593..0000000000000000000000000000000000000000
--- a/benchmarks/retrieval/retrieve.py
+++ /dev/null
@@ -1,108 +0,0 @@
-"""Script to call retrieval on a benchmark dataset.
-
-Make sure to `pip install ir_measures` before running this script.
-"""
-
-import json
-import logging
-import os
-import time
-
-import configargparse
-from dotenv import load_dotenv
-from ir_measures import MAP, MRR, P, Qrel, R, Rprec, ScoredDoc, calc_aggregate, nDCG
-
-import sage.config
-from sage.data_manager import GitHubRepoManager
-from sage.retriever import build_retriever_from_args
-
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger()
-logger.setLevel(logging.INFO)
-
-load_dotenv()
-
-
-def main():
- parser = configargparse.ArgParser(
- description="Runs retrieval on a benchmark dataset.", ignore_unknown_config_file_keys=True
- )
- parser.add("--benchmark", required=True, help="Path to the benchmark dataset.")
- parser.add(
- "--gold-field", default="context_files", help="Field in the benchmark dataset that contains the golden answers."
- )
- parser.add(
- "--question-field", default="question", help="Field in the benchmark dataset that contains the questions."
- )
- parser.add(
- "--logs-dir",
- default=None,
- help="Path where to output predictions and metrics. Optional, since metrics are also printed to console.",
- )
-
- parser.add("--max-instances", default=None, type=int, help="Maximum number of instances to process.")
-
- validator = sage.config.add_all_args(parser)
- args = parser.parse_args()
- validator(args)
-
- repo_manager = GitHubRepoManager.from_args(args)
- retriever = build_retriever_from_args(args, repo_manager)
-
- with open(args.benchmark, "r") as f:
- benchmark = json.load(f)
- if args.max_instances is not None:
- benchmark = benchmark[: args.max_instances]
-
- golden_docs = [] # List of ir_measures.Qrel objects
- retrieved_docs = [] # List of ir_measures.ScoredDoc objects
-
- for question_idx, item in enumerate(benchmark):
- print(f"Processing question {question_idx}...")
-
- query_id = str(question_idx) # Solely needed for ir_measures library.
-
- for golden_filepath in item[args.gold_field]:
- # All the file paths in the golden answer are equally relevant for the query (i.e. the order is irrelevant),
- # so we set relevance=1 for all of them.
- golden_docs.append(Qrel(query_id=query_id, doc_id=golden_filepath, relevance=1))
-
- # Make a retrieval call for the current question.
- retrieved = retriever.invoke(item[args.question_field])
- item["retrieved"] = []
- for doc_idx, doc in enumerate(retrieved):
- # The absolute value of the scores below does not affect the metrics; it merely determines the ranking of
- # the retrieved documents. The key of the score varies depending on the underlying retriever. If there's no
- # score, we use 1/(doc_idx+1) since it preserves the order of the documents.
- score = doc.metadata.get("score", doc.metadata.get("relevance_score", 1 / (doc_idx + 1)))
- retrieved_docs.append(ScoredDoc(query_id=query_id, doc_id=doc.metadata["file_path"], score=score))
- # Update the output dictionary with the retrieved documents.
- item["retrieved"].append({"file_path": doc.metadata["file_path"], "score": score})
-
- if "answer" in item:
-            item.pop("answer")  # The long answer text makes the output file harder to read.
-
- print("Calculating metrics...")
- results = calc_aggregate([Rprec, P @ 1, R @ 3, nDCG @ 3, MAP, MRR], golden_docs, retrieved_docs)
- results = {str(key): value for key, value in results.items()}
- if args.logs_dir:
- if not os.path.exists(args.logs_dir):
- os.makedirs(args.logs_dir)
-
- out_data = {
- "data": benchmark,
- "metrics": results,
- "flags": vars(args), # For reproducibility.
- }
-
- output_file = os.path.join(args.logs_dir, f"{time.time()}.json")
- with open(output_file, "w") as f:
- json.dump(out_data, f, indent=4)
-
-    for key in sorted(results.keys()):
-        print(f"{key}: {results[key]}")
-    if args.logs_dir:
-        print(f"Predictions and metrics saved to {output_file}")
-
-
-if __name__ == "__main__":
- main()
diff --git a/benchmarks/retrieval/retrieve_kaggle.py b/benchmarks/retrieval/retrieve_kaggle.py
deleted file mode 100644
index 10e7af6fdc477897f887f65ca37ce1ecba15270d..0000000000000000000000000000000000000000
--- a/benchmarks/retrieval/retrieve_kaggle.py
+++ /dev/null
@@ -1,74 +0,0 @@
-"""Script to call retrieval on the Kaggle dataset.
-
-Steps:
-1. Make sure that your repository is already indexed. You can find instructions in the README for how to run the `sage-index` command.
-2. Download the test file from the Kaggle competition (https://www.kaggle.com/competitions/code-retrieval-for-hugging-face-transformers/data). You will pass the path to this file via the --benchmark flag below.
-3. Run this script:
-```
-# After you cloned the repository:
-cd sage
-pip install -e .
-
-# Run the actual retrieval script. Your flags may vary, but this is one example:
-python benchmarks/retrieval/retrieve_kaggle.py --benchmark=/path/to/kaggle/test/file.csv --mode=remote --pinecone-index-name=your-index --index-namespace=your-namespace
-```
-To see a full list of flags, check out config.py (https://github.com/Storia-AI/sage/blob/main/sage/config.py).
-"""
-
-import csv
-import json
-import logging
-
-import configargparse
-
-import sage.config
-from sage.retriever import build_retriever_from_args
-
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger()
-logger.setLevel(logging.INFO)
-
-
-def main():
- parser = configargparse.ArgParser(
- description="Runs retrieval on the Kaggle dataset.", ignore_unknown_config_file_keys=True
- )
- parser.add("--benchmark", required=True, help="Path to the Kaggle dataset.")
- parser.add("--output-file", required=True, help="Path to the output file with predictions.")
-
- sage.config.add_config_args(parser)
- sage.config.add_llm_args(parser) # Necessary for --multi-query-retriever, which calls an LLM.
- sage.config.add_embedding_args(parser)
- sage.config.add_vector_store_args(parser)
- sage.config.add_reranking_args(parser)
- args = parser.parse_args()
- sage.config.validate_vector_store_args(args)
-
- retriever = build_retriever_from_args(args)
-
- with open(args.benchmark, "r") as f:
- benchmark = csv.DictReader(f)
- benchmark = [row for row in benchmark]
-
- outputs = []
- for question_idx, item in enumerate(benchmark):
- print(f"Processing question {question_idx}...")
-
- retrieved = retriever.invoke(item["question"])
- # Sort by score in descending order.
-        retrieved = sorted(
-            retrieved, key=lambda doc: doc.metadata.get("score", doc.metadata.get("relevance_score", 0.0)), reverse=True
-        )
- # Keep top 3, since the Kaggle competition only evaluates the top 3.
- retrieved = retrieved[:3]
- retrieved_filenames = [doc.metadata["file_path"] for doc in retrieved]
- outputs.append((item["id"], json.dumps(retrieved_filenames)))
-
- with open(args.output_file, "w") as f:
- csv_writer = csv.writer(f)
- csv_writer.writerow(["id", "documents"])
- csv_writer.writerows(outputs)
-
-
-if __name__ == "__main__":
- main()
diff --git a/benchmarks/retrieval/sample.json b/benchmarks/retrieval/sample.json
deleted file mode 100644
index 01d2bb7d9715ff7758515b36e6d1fc5972d443e5..0000000000000000000000000000000000000000
--- a/benchmarks/retrieval/sample.json
+++ /dev/null
@@ -1,177 +0,0 @@
-[
- {
- "repo": "huggingface/transformers",
- "commit": "7bb1c99800d235791dace10305731f377db8077b",
- "context_files": [
- "huggingface/transformers/src/transformers/commands/serving.py",
- "huggingface/transformers/src/transformers/pipelines/__init__.py"
- ],
- "question": "With the introduction of a new translation service for \"en_to_es\", how does `serve_command_factory` ensure the server is prepared to handle this specific task efficiently?",
- "answer": "The `serve_command_factory` function is designed to dynamically configure and deploy a server that can handle a variety of tasks, including complex and parameterized tasks such as language-specific translations. When a new translation service for \"en_to_es\" is introduced, the function ensures efficient handling through several key steps and mechanisms:\n\n1. **Task Specification and Validation**: The function begins by accepting command line arguments (`args`) which specify the task among other parameters. The task, in this case \"translation_en_to_es\", is passed to the `pipeline` function. This function internally calls `check_task` to validate and process the task string. The `check_task` function checks if the task is recognized and supported, and it returns structured data that includes default configurations and any task-specific options necessary for the task.\n\n2. **Dynamic Pipeline Configuration**: The `pipeline` function uses the information returned by `check_task` to dynamically configure a pipeline suitable for the specified task. This includes loading the appropriate model, tokenizer, and any additional components like feature extractors or image processors that are required for the task. For translation tasks, this might involve loading a specific model that is optimized for the language pair \"en_to_es\".\n\n3. **Model and Tokenizer Loading**: Depending on the task and the provided arguments, the `pipeline` function loads the necessary model and tokenizer. If the model or tokenizer is not explicitly specified in the command line arguments, the function will load default models and tokenizers that are appropriate for the task. This ensures that the server is using the most suitable tools for the translation between English and Spanish.\n\n4. 
**Device and Performance Configuration**: The `pipeline` function also handles device allocation (e.g., CPU, GPU) based on the provided arguments, ensuring that the model runs on the optimal hardware. This is crucial for performance, especially when handling complex tasks like real-time translation.\n\n5. **Server Initialization**: Once the pipeline is configured, `serve_command_factory` initializes a `ServeCommand` with the configured pipeline and server parameters such as host, port, and number of workers. This command encapsulates all the necessary settings to run the server efficiently.\n\n6. **Adaptability and Scalability**: The design of `serve_command_factory` and the underlying `pipeline` function allows for easy adaptation to new tasks and changes in task requirements. Adding support for a new translation like \"en_to_es\" primarily involves ensuring that the `check_task` function recognizes the task and that appropriate models and tokenizers are available and correctly configured.\n\nIn summary, `serve_command_factory` ensures that the server is prepared to handle the \"en_to_es\" translation task efficiently by dynamically configuring the server based on the specified task, validating and processing the task to load appropriate models and tools, and setting up the server with optimal hardware and network settings. This process guarantees that the server can adapt to new and complex tasks with minimal manual intervention.",
- "expected_identifiers": [
- "args",
- "check_task",
- "pipeline",
- "ServeCommand"
- ]
- },
- {
- "repo": "huggingface/transformers",
- "commit": "7bb1c99800d235791dace10305731f377db8077b",
- "context_files": [
- "huggingface/transformers/src/transformers/models/esm/modeling_esmfold.py",
- "huggingface/transformers/src/transformers/models/esm/openfold_utils/feats.py",
- "huggingface/transformers/src/transformers/models/esm/openfold_utils/tensor_utils.py"
- ],
- "question": "In a high-throughput setting where multiple protein structures are processed simultaneously, how does `EsmForProteinFolding.output_to_pdb` ensure accurate and independent structural representation in the resulting PDB files?",
- "answer": "In a high-throughput setting where multiple protein structures are processed simultaneously, the function `output_to_pdb` ensures accurate and independent structural representation in the resulting PDB files through a combination of specialized tensor operations and careful indexing. This is achieved primarily through the use of the `atom14_to_atom37` function, which itself relies on the `batched_gather` function to correctly map atom positions from a simplified model output to a more detailed atomic representation.\n\n### Detailed Workflow:\n\n1. **Batch Processing and Tensor Operations**:\n - The `output_to_pdb` function begins by converting all tensor data to the CPU and converting them to NumPy arrays for easier manipulation. This step is crucial for performance and compatibility with subsequent operations that may not be optimized for GPU tensors.\n\n2. **Mapping Atom Positions**:\n - The function `atom14_to_atom37` is called within `output_to_pdb`. This function is responsible for expanding the reduced atom representation (14 atoms per amino acid) to a fuller representation (37 atoms per amino acid). It uses the `batched_gather` function to achieve this mapping accurately across potentially multiple proteins in a batch.\n\n3. **Complex Indexing with `batched_gather`**:\n - `batched_gather` plays a critical role in ensuring that the atom positions are mapped correctly. It constructs a complex indexing tuple that combines batch indices with the provided indices for gathering (`inds`). This tuple (`ranges`) includes both batch dimensions and the specific indices where atoms need to be gathered from the `atom14` tensor.\n - The use of `ranges` in `batched_gather` ensures that each protein's data is handled independently, preventing any cross-contamination or mixing of data between different proteins in the batch. This is crucial for maintaining the structural integrity of each protein.\n\n4. 
**Application of Mask and Final Adjustments**:\n - After mapping the positions, `atom14_to_atom37` applies a mask (`batch[\"atom37_atom_exists\"]`) to ensure that only existing atoms are considered. This step further ensures the accuracy of the structural data by zeroing out positions of non-existent atoms, preventing any erroneous data from affecting the structural representation.\n\n5. **Generation of PDB Data**:\n - Back in `output_to_pdb`, for each protein in the batch, an instance of `OFProtein` is created with the mapped atom positions, types, and other relevant data. The `to_pdb` function is then used to convert these protein data into the PDB format, ready for downstream applications like molecular dynamics simulations.\n\n### Conclusion:\n\nThrough the careful use of tensor operations, complex indexing, and data masking, `output_to_pdb` ensures that each protein's structural data is accurately and independently represented in the PDB outputs. This methodical approach is essential in high-throughput settings, where the accuracy and integrity of structural data are paramount for subsequent scientific analysis and applications.",
- "expected_identifiers": [
- "atom14_to_atom37",
- "batched_gather",
- "batch[\"atom37_atom_exists\"]",
- "OFProtein"
- ]
- },
- {
- "repo": "huggingface/transformers",
- "commit": "7bb1c99800d235791dace10305731f377db8077b",
- "context_files": [
- "huggingface/transformers/src/transformers/models/auto/auto_factory.py",
- "huggingface/transformers/src/transformers/dynamic_module_utils.py"
- ],
- "question": "Following a security update in the production environment that limits internet connectivity, how does `_BaseAutoModelClass.from_pretrained` guarantee that the loaded model adheres strictly to the predefined version and settings?",
- "answer": "In the updated production environment with restricted internet connectivity, `_BaseAutoModelClass.from_pretrained` ensures that the model loaded adheres strictly to the predefined version and settings through several key mechanisms, primarily involving the management of model files and code via a version control system and secure access to private repositories.\n\n### Version Control and Revision Specification\n\nThe function leverages a version control system that allows users to specify exact revisions of the model or code they wish to use. This is evident in the handling of the `revision` parameter in functions like `get_cached_module_file` and `get_class_from_dynamic_module`. The `revision` parameter can accept any identifier allowed by git, such as a branch name, a tag name, or a commit id. This ensures that the exact version of the model or code that was tested and approved in other environments (like development or staging) is the same version being deployed in production.\n\nFor example, in the `get_cached_module_file` function, the `revision` parameter is used to fetch the specific version of a module file from a repository:\n```python\nresolved_module_file = cached_file(\n pretrained_model_name_or_path,\n module_file,\n cache_dir=cache_dir,\n force_download=force_download,\n proxies=proxies,\n resume_download=resume_download,\n local_files_only=local_files_only,\n token=token,\n revision=revision,\n repo_type=repo_type,\n _commit_hash=_commit_hash,\n)\n```\n\n### Secure Access to Private Repositories\n\nThe function can authenticate access to private repositories using tokens, which is crucial when operating in environments with strict security protocols. The `token` parameter, which can be set to a string or `True` (to use the token generated by `huggingface-cli login`), is used to authenticate HTTP requests for remote files. 
This is handled securely in both `get_cached_module_file` and `get_class_from_dynamic_module`, ensuring that only authorized users can access private model files or code.\n\nFor instance, in `get_class_from_dynamic_module`, the `token` parameter is used to authenticate and download the necessary module file:\n```python\nfinal_module = get_cached_module_file(\n repo_id,\n module_file + \".py\",\n cache_dir=cache_dir,\n force_download=force_download,\n resume_download=resume_download,\n proxies=proxies,\n token=token,\n revision=code_revision,\n local_files_only=local_files_only,\n repo_type=repo_type,\n)\n```\n\n### Handling Restricted Internet Connectivity\n\nIn environments with limited internet access, the `local_files_only` parameter becomes particularly important. This parameter, when set to `True`, forces the function to only look for model files locally and not attempt to download them from the internet. This is crucial for ensuring that the model loading process does not fail due to lack of internet access and adheres to strict security protocols that might block external internet connections.\n\n### Conclusion\n\nBy utilizing these mechanisms, `_BaseAutoModelClass.from_pretrained` ensures that the model loaded in a production environment with restricted internet access is exactly the version specified, using secure and authenticated access where necessary. This approach guarantees consistency, reproducibility, and adherence to security protocols across different environments.",
- "expected_identifiers": [
- "revision",
- "token",
- "local_files_only"
- ]
- },
- {
- "repo": "huggingface/transformers",
- "commit": "7bb1c99800d235791dace10305731f377db8077b",
- "context_files": [
- "huggingface/transformers/src/transformers/models/auto/auto_factory.py",
- "huggingface/transformers/src/transformers/utils/doc.py"
- ],
- "question": "When developing a specialized model class in the Transformers library, how does `auto_class_update` ensure that the new class's methods are tailored specifically for its requirements while preserving the functionality of the original methods from the base class?",
- "answer": "In the Transformers library, the `auto_class_update` function plays a crucial role in dynamically creating specialized model classes that inherit functionalities from a base class but also have unique customizations. This is particularly important when different model classes need specific configurations or preprocessing steps that are not shared across all models.\n\nThe core mechanism that allows `auto_class_update` to achieve this functionality without altering the behavior of the base class methods lies in its use of the `copy_func` function. Here's how it works step-by-step:\n\n1. **Copying the Function**: `copy_func` is used to create an exact copy of the methods `from_config` and `from_pretrained` from the base class `_BaseAutoModelClass`. This is done by duplicating the `__code__` object of these methods. The `__code__` object contains the compiled executable code that the Python interpreter runs. By copying this code object, the new function retains the exact behavior and logic of the original function.\n\n2. **Customization of the Copied Function**: After copying, `auto_class_update` modifies the docstrings of these methods to tailor them to the specific subclass. This involves inserting a specific `head_doc`, replacing placeholders like `\"BaseAutoModelClass\"` with the subclass's name, and updating example checkpoints specific to the model type (e.g., `\"google-bert/bert-base-cased\"`). These modifications are crucial for providing accurate and relevant documentation and guidance specific to each subclass.\n\n3. **Re-assignment as Class Methods**: Once the functions are copied and customized, they are re-assigned to the subclass as class methods. This is done using `classmethod(from_config)` and `classmethod(from_pretrained)`. This step ensures that these methods, now tailored and documented specifically for the subclass, are callable on the subclass itself.\n\n4. 
**Preservation of Base Class Functionality**: Since the original methods are copied before being modified, the base class `_BaseAutoModelClass` retains its original `from_config` and `from_pretrained` methods without any changes. This isolation ensures that modifications specific to one subclass do not impact the behavior or documentation of these methods in the base class or any other subclasses.\n\nBy following this process, `auto_class_update` ensures that each subclass in the Transformers library can have methods that are specifically tailored to its requirements, both in terms of functionality and documentation, while preserving the integrity and functionality of the original methods from the base class. This approach enhances modularity and flexibility in the library, allowing developers to easily extend and customize model classes for various use cases.",
- "expected_identifiers": [
- "__code__",
- "copy_func",
- "from_config",
- "from_pretrained"
- ]
- },
- {
- "repo": "huggingface/transformers",
- "commit": "7bb1c99800d235791dace10305731f377db8077b",
- "context_files": [
- "huggingface/transformers/src/transformers/models/megatron_gpt2/checkpoint_reshaping_and_interoperability.py",
- "huggingface/transformers/src/transformers/modeling_utils.py"
- ],
- "question": "Given a system limitation of 5GB per file, how does `convert_checkpoint_from_megatron_to_transformers` manage the storage of a large model's data to comply with this restriction?",
- "answer": "The `convert_checkpoint_from_megatron_to_transformers` function manages the storage of a large model's data to comply with a system limitation of 5GB per file by utilizing the `shard_checkpoint` function to split the model's state dictionary into multiple sub-checkpoints, each of which does not exceed the specified maximum size.\n\nHere's a detailed breakdown of how this is achieved:\n\n1. **Sharding Process**: The `shard_checkpoint` function is called within `convert_checkpoint_from_megatron_to_transformers` to handle the division of the model's weights into smaller parts or shards. This function takes the entire state dictionary of the model (`output_state_dict`) and a maximum shard size as inputs.\n\n2. **Size Calculation**: The function calculates the byte size of each tensor in the state dictionary using the `dtype_byte_size` function. This function determines the number of bytes each element of a tensor occupies in memory, based on the tensor's data type (`dtype`). This calculation is crucial as it helps in accurately assessing how much space each tensor will take when saved as part of a shard.\n\n3. **Iterative Sharding**: The `shard_checkpoint` iterates through each tensor in the state dictionary and adds them to the current shard until adding another tensor would exceed the maximum shard size (5GB in this scenario). When this limit is reached, a new shard is started. This ensures that no individual shard file exceeds the specified size limit.\n\n4. **Handling Oversized Tensors**: If a single tensor is larger than the maximum shard size, it is placed in its own shard. This is a necessary exception to prevent the function from failing due to an inability to split a tensor.\n\n5. **Saving Shards**: Each shard is saved as a separate file. The naming convention and indexing ensure that each part of the model can be identified and accessed correctly. 
The function also generates an index file if the model is split into multiple shards, detailing where each parameter is stored.\n\n6. **Parameter Mapping**: The function maintains a mapping (`weight_map`) of model parameters to their respective shard files. This mapping is crucial for efficiently loading the model from its sharded state.\n\nBy following these steps, the `convert_checkpoint_from_megatron_to_transformers` function ensures that each shard of the converted model adheres to the 5GB file size limit imposed by the system. This methodical sharding allows for efficient storage and handling of large models without exceeding system file size limitations.",
- "expected_identifiers": [
- "shard_checkpoint",
- "dtype_byte_size",
- "output_state_dict",
- "weight_map"
- ]
- },
- {
- "repo": "huggingface/transformers",
- "commit": "7bb1c99800d235791dace10305731f377db8077b",
- "context_files": [
- "huggingface/transformers/src/transformers/quantizers/quantizer_hqq.py",
- "huggingface/transformers/src/transformers/integrations/hqq.py"
- ],
- "question": "In a scenario where a neural network model is being optimized for deployment, how does `HqqHfQuantizer._process_model_before_weight_loading` ensure that each linear module is appropriately and uniquely quantized?",
- "answer": "In the scenario where a neural network model is being optimized for deployment using the `HqqHfQuantizer._process_model_before_weight_loading` function, the process of ensuring that each linear module is appropriately and uniquely quantized involves several key steps and functions.\n\n1. **Tagging Modules with Unique Identifiers**: The process begins with the `get_linear_tags` function, which is responsible for identifying and tagging all linear modules within the model. This function uses a `set` to collect the names of these modules, which inherently ensures that each tag is unique (since sets do not allow duplicates). This is crucial because it prevents any confusion or errors in later stages when quantization parameters are applied to these tags.\n\n2. **Applying Quantization Configuration**: Once the linear modules are tagged, the `prepare_for_hqq_linear` function takes over. This function receives a `quantization_config` and a list of modules not to convert. It first calls `autoname_modules` to ensure each module in the model has a unique name, and then retrieves the linear tags using `get_linear_tags`. The function then filters these tags to exclude any specified in `skip_modules` or `modules_to_not_convert`, ensuring that the quantization process is applied only to the relevant modules.\n\n3. **Mapping Quantization Parameters**: The core of the quantization process happens when `prepare_for_hqq_linear` maps the quantization parameters to each linear tag. This is done by creating a dictionary (`patch_params`) where each key is a linear tag and the value is the corresponding quantization parameter. If specific quantization parameters are not provided for a tag, a default configuration is applied. This mapping ensures that each linear module (identified uniquely by its tag) receives a tailored set of quantization parameters.\n\n4. 
**Updating Model Configuration**: After mapping the quantization parameters, the `prepare_for_hqq_linear` function updates the model's configuration to include these parameters, ensuring that each linear module's configuration reflects its unique quantization settings. This step is crucial for the actual quantization process, where linear modules might be replaced with their quantized counterparts (`HQQLinear`), depending on the configuration.\n\n5. **Final Verification and Logging**: The function checks if any linear modules have been replaced and logs a warning if no modules were found for quantization. This serves as a final check to ensure that the quantization process has been applied as expected.\n\nIn summary, the `HqqHfQuantizer._process_model_before_weight_loading` function ensures that each linear module is uniquely and appropriately quantized by meticulously tagging each module, applying a tailored quantization configuration, and updating the model to reflect these settings. This process is designed to optimize the model's performance for deployment, ensuring that each module operates efficiently and accurately under the constraints of quantization.",
- "expected_identifiers": [
- "get_linear_tags",
- "autoname_modules",
- "prepare_for_hqq_linear",
- "patch_params"
- ]
- },
- {
- "repo": "huggingface/transformers",
- "commit": "7bb1c99800d235791dace10305731f377db8077b",
- "context_files": [
- "huggingface/transformers/src/transformers/models/esm/modeling_esmfold.py",
- "huggingface/transformers/src/transformers/models/esm/openfold_utils/loss.py"
- ],
- "question": "When analyzing a protein sequence with low complexity using `EsmForProteinFolding.forward`, how is the stability and definition of the output ensured?",
- "answer": "When analyzing a protein sequence with low complexity using the `EsmForProteinFolding.forward` function, the stability and definition of the output are ensured through several key mechanisms embedded within the function's implementation, particularly in how it handles normalization and potential numerical instabilities.\n\n1. **Normalization of Residue Weights**: In the `compute_tm` function, residue weights are normalized by their sum, with the addition of a small constant `eps` (epsilon) to prevent division by zero. This is crucial when dealing with sequences of low complexity where certain residues might be overrepresented or underrepresented. The normalization step is represented in the code as:\n ```python\n normed_residue_mask = residue_weights / (eps + residue_weights.sum())\n ```\n Here, `eps` acts as a safeguard against division by zero, ensuring that the function remains numerically stable and produces defined outputs even when the sum of residue weights is extremely small or zero.\n\n2. **Weighted Average Calculation**: The function calculates a weighted average of the Template Modeling (TM) scores across different bins, which is critical for obtaining a reliable TM score. This is done using the normalized residue weights, ensuring that each residue's contribution is proportionate to its presence, thus maintaining accuracy and stability in the final score calculation:\n ```python\n per_alignment = torch.sum(predicted_tm_term * normed_residue_mask, dim=-1)\n ```\n This step aggregates the TM scores across all residues, factoring in their normalized weights, which is particularly important in low complexity sequences where certain residues might dominate.\n\n3. **Handling of Edge Cases**: The use of `eps` in the normalization process is a direct method to handle edge cases, such as sequences with low complexity or unusual amino acid distributions. 
By ensuring that the denominator in the normalization step is never zero, the function avoids potential runtime errors (like NaN or infinite values), which could disrupt the analysis process.\n\n4. **Integration within `EsmForProteinFolding.forward`**: The stability and definition of outputs from the `EsmForProteinFolding.forward` function are further supported by how `compute_tm` integrates with other components of the model. The TM scores computed are used alongside other structural predictions, contributing to a comprehensive evaluation of the predicted protein structures. This integration ensures that the outputs are not only stable and defined but also meaningful in the context of protein structure prediction.\n\nIn summary, the `EsmForProteinFolding.forward` function ensures stable and defined outputs for protein structure predictions, particularly in scenarios involving low complexity sequences, by employing robust normalization techniques and handling potential numerical instabilities through the careful addition of a small epsilon value in critical calculations. This approach guarantees that the function can reliably process a wide range of input data without encountering computational errors.",
- "expected_identifiers": [
- "normed_residue_mask",
- "eps",
- "residue_weights / (eps + residue_weights.sum())",
- "torch.sum(predicted_tm_term * normed_residue_mask, dim=-1)"
- ]
- },
- {
- "repo": "huggingface/transformers",
- "commit": "7bb1c99800d235791dace10305731f377db8077b",
- "context_files": [
- "huggingface/transformers/src/transformers/pipelines/question_answering.py",
- "huggingface/transformers/src/transformers/data/processors/squad.py"
- ],
- "question": "In a scenario where the textual data includes unusually lengthy paragraphs, how does `QuestionAnsweringPipeline.preprocess` ensure comprehensive coverage of all context tokens in the model's input sequences?",
- "answer": "In scenarios where the textual data includes unusually lengthy paragraphs that exceed the model's maximum input length, the `QuestionAnsweringPipeline.preprocess` function ensures comprehensive coverage of all context tokens in the model's input sequences through a meticulous management of tokenization and handling of overflow tokens. This process is crucial for maintaining the integrity and continuity of the context information, which is essential for the model to accurately answer questions based on the provided context.\n\n### Step-by-Step Explanation:\n\n1. **Tokenization and Pairing**:\n The function begins by tokenizing the question and context separately. Depending on the tokenizer's configuration (`tokenizer.padding_side`), the question and context are arranged in a specific order (either question first or context first). This is handled in the lines where `encoded_inputs` is defined using `self.tokenizer(text, text_pair, ...)`. \n\n2. **Handling Long Contexts with Overflow Tokens**:\n The key parameter here is `return_overflowing_tokens=True` within the tokenizer call. This setting ensures that when the combined length of the question and context exceeds `max_seq_len`, the tokenizer automatically generates additional input sequences that contain the \"overflow\" tokens from the context. These sequences overlap by a number of tokens defined by `doc_stride`, which is calculated as `min(max_seq_len // 2, 128)`.\n\n3. **Creating Overlapping Spans**:\n The overlapping spans are crucial for ensuring that tokens near the boundaries of a sequence are also seen in different contextual surroundings, enhancing the model's ability to understand and answer questions about tokens that appear near the maximum sequence length limit. This overlap is managed by the `stride` parameter in the tokenizer, which is set to `doc_stride`.\n\n4. 
**Feature Construction**:\n For each span generated from the overflowing tokens, the function constructs a feature object that includes not only the token ids (`input_ids`) but also attention masks, token type ids, and a special mask (`p_mask`) which indicates which tokens can be part of an answer. The `p_mask` is particularly important as it helps the model distinguish between context tokens (potential answer locations) and non-context tokens (like those belonging to the question or special tokens).\n\n5. **Yielding Processed Features**:\n Each feature constructed from the spans is then yielded one by one, with additional metadata such as whether it is the last feature of the example. This is handled in the loop `for i, feature in enumerate(features):` where each feature is prepared according to the model's requirements, potentially converting them into tensors suitable for the model's computation framework (PyTorch or TensorFlow).\n\n### Conclusion:\n\nBy managing the tokenization and overflow tokens effectively, `QuestionAnsweringPipeline.preprocess` ensures that every token from a lengthy context is included in at least one input sequence to the model. This comprehensive coverage is achieved through the creation of multiple, overlapping input sequences that ensure no contextual information is lost, thereby enabling the model to perform accurately even with contexts that far exceed its maximum input length. This methodical approach is crucial for handling real-world data where lengthy paragraphs are common, ensuring robust model performance across varied datasets.",
- "expected_identifiers": [
- "return_overflowing_tokens=True",
- "doc_stride",
- "p_mask"
- ]
- },
- {
- "repo": "huggingface/transformers",
- "commit": "7bb1c99800d235791dace10305731f377db8077b",
- "context_files": [
- "huggingface/transformers/examples/research_projects/movement-pruning/masked_run_squad.py",
- "huggingface/transformers/src/transformers/data/processors/squad.py"
- ],
- "question": "Given the challenge of training models on datasets with varying context lengths, how does `load_and_cache_examples` in `examples/research_projects/movement-pruning/masked_run_squad.py` ensure the production of a dataset that supports accurate answer position prediction by the models?",
- "answer": "The `load_and_cache_examples` function in `examples/research_projects/movement-pruning/masked_run_squad.py` is designed to ensure that the dataset produced supports accurate answer position prediction by models, even when dealing with varying context lengths. This is achieved through a series of steps that involve careful handling and processing of the data, particularly when contexts exceed the model's maximum sequence length. Here's how the function manages this:\n\n1. **Data Loading and Caching**: Initially, the function checks if a cached version of the processed data exists. If it does, it loads the features, dataset, and examples directly from the cache, bypassing the need for reprocessing. This step ensures efficiency and consistency in data usage across different training runs.\n\n2. **Dynamic Data Processing**: If no cached data is available, the function processes the raw data to generate features suitable for model training. This involves tokenizing the text and converting the SQuAD examples into features that models can understand and learn from.\n\n3. **Handling Extended Contexts**: The core of handling varying context lengths lies in the `squad_convert_examples_to_features` function, which is called within `load_and_cache_examples`. This function uses `squad_convert_example_to_features` to process each example individually.\n\n4. **Segmentation and Token Index Adjustment**: In `squad_convert_example_to_features`, the context is potentially split into multiple spans if its length exceeds the model's maximum sequence length. This is crucial because it allows the model to handle long contexts by breaking them down into manageable parts. Each span is processed to ensure that the start and end positions of answers are correctly adjusted within the tokenized context. 
This adjustment is handled by the `_improve_answer_span` function, which ensures that the answer spans are accurately placed within the tokens, even if the context is segmented.\n\n5. **Feature Construction**: Each span is then converted into a set of features, including input IDs, attention masks, token type IDs, and the positions of the answers. Special care is taken to mark tokens that cannot be part of the answers (using a p_mask), and to identify the maximum context for each token, which is critical for understanding which part of the split context a token belongs to.\n\n6. **Dataset Compilation**: After processing, the features are compiled into a dataset format (either PyTorch or TensorFlow, based on the configuration). This dataset includes all necessary information for the model to learn from, including the context, the question, and the correct positions of the answers.\n\nBy carefully managing the tokenization, segmentation, and feature construction processes, `load_and_cache_examples` ensures that the dataset it produces allows models to accurately predict answer positions, regardless of the length of the context. This capability is essential for training robust question-answering models that can handle real-world data, where context lengths can vary significantly.",
- "expected_identifiers": [
- "squad_convert_examples_to_features",
- "squad_convert_example_to_features",
- "_improve_answer_span",
- "p_mask"
- ]
- },
- {
- "repo": "huggingface/transformers",
- "commit": "7bb1c99800d235791dace10305731f377db8077b",
- "context_files": [
- "huggingface/transformers/src/transformers/modeling_flax_utils.py",
- "huggingface/transformers/src/transformers/utils/hub.py"
- ],
- "question": "In a scenario where network conditions are suboptimal, how does `FlaxPreTrainedModel.from_pretrained` manage to reduce the model loading time?",
- "answer": "In scenarios where network conditions are suboptimal, the `FlaxPreTrainedModel.from_pretrained` function effectively reduces model loading time by leveraging a sophisticated caching mechanism. This mechanism is crucial for managing the download and storage of model shards, ensuring efficient and faster model initialization.\n\n### Caching Mechanism:\nThe function first checks if the required model shards are already available in the local cache before attempting any network requests. This is achieved through the `try_to_load_from_cache` function, which inspects the cache for the presence of the last shard of the model. If the last shard is found in the cache, it is likely that all previous shards are also cached, thus avoiding the need for further network requests.\n\n### Download and Cache Management:\nIf the shards are not found in the cache, `FlaxPreTrainedModel.from_pretrained` proceeds to download them. Each shard's presence is verified using the `cached_file` function, which handles the downloading and caching of the shard if it is not already present. This function also supports resuming downloads, which is particularly useful in suboptimal network conditions where downloads might be interrupted.\n\n### Efficient Shard Handling:\nThe function `get_checkpoint_shard_files` is specifically designed to manage sharded model files. It reads the checkpoint index file to determine all the necessary shards for the model and then ensures each shard is either fetched from the cache or downloaded. This process is streamlined by the use of a progress bar (managed by `tqdm`), which provides visual feedback on the download process, enhancing user experience especially in network-constrained environments.\n\n### Impact of Caching on Model Loading Time:\nBy prioritizing cached shards, `FlaxPreTrainedModel.from_pretrained` significantly reduces the dependency on network bandwidth and stability. 
This is particularly beneficial in scenarios with limited network resources, as it minimizes the time spent in downloading model components. The caching mechanism ensures that once a model shard is downloaded and stored locally, subsequent loads of the same model will utilize the cached versions, thereby bypassing the network entirely and leading to much faster model initialization times.\n\n### Conclusion:\nThe caching strategy employed by `FlaxPreTrainedModel.from_pretrained` not only optimizes the use of network resources but also ensures consistent and reduced model loading times, regardless of network conditions. This approach is instrumental in scenarios where models need to be switched frequently or reloaded, providing a seamless and efficient user experience.",
- "expected_identifiers": [
- "try_to_load_from_cache",
- "cached_file",
- "get_checkpoint_shard_files",
- "tqdm"
- ]
- },
- {
- "repo": "huggingface/transformers",
- "commit": "7bb1c99800d235791dace10305731f377db8077b",
- "context_files": [
- "huggingface/transformers/examples/research_projects/information-gain-filtration/run_clm_igf.py",
- "huggingface/transformers/examples/research_projects/information-gain-filtration/igf/igf.py"
- ],
- "question": "In light of recent dataset size restrictions for training purposes, how does `generate_n_pairs` maintain compliance by ensuring the objective set adheres to the specified size and article length requirements?",
- "answer": "The `generate_n_pairs` function ensures compliance with dataset size restrictions by meticulously managing the creation of the objective set through its subordinate function `generate_datasets`. This process is governed by specific parameters and conditions set within the code to meet the required criteria of size and article length.\n\n1. **Size of the Objective Set**: The function `generate_datasets` is designed to create an objective set that contains exactly the number of articles specified by the `number` parameter, which is passed from `generate_n_pairs` as `size_objective_set`. In the provided code, this value is set to 100. The loop within `generate_datasets` that populates the `objective_set` list includes a condition to break once the length of this list reaches the specified `number` (see the line `if len(objective_set) >= number: break`). This ensures that no more than 100 articles are added to the objective set, directly adhering to the dataset size restrictions.\n\n2. **Article Length Management**: The function also manages the length of each article in the objective set based on the `context_len` parameter. If `trim` is set to `True`, the function trims the articles to ensure they do not exceed the specified `context_len`. This is achieved by selecting a starting point randomly within the article and then slicing the article to obtain a segment of the specified `context_len` (see the line `objective_set.append(example[0, start : start + context_len])`). This ensures that each article in the objective set adheres to the length restrictions.\n\n3. **Compliance with Regulations**: By strictly controlling both the number of articles and their lengths as described, `generate_n_pairs` ensures that the objective set complies with new regulations requiring training datasets to contain no more than 100 articles, each of a specified maximum length. 
This compliance is crucial for ethical review and adherence to training dataset standards.\n\nIn summary, `generate_n_pairs` maintains compliance with dataset size and article length restrictions through careful implementation in `generate_datasets`, which explicitly controls the size of the objective set and trims articles to the required length based on the parameters provided. This methodical approach ensures that the objective set meets specified criteria, crucial for adhering to regulatory standards.",
- "expected_identifiers": [
- "generate_n_pairs",
- "generate_datasets",
- "size_objective_set",
- "context_len"
- ]
- }
-]
\ No newline at end of file
diff --git a/sage/__init__.py b/code_chatbot/__init__.py
similarity index 100%
rename from sage/__init__.py
rename to code_chatbot/__init__.py
diff --git a/code_chatbot/agent_workflow.py b/code_chatbot/agent_workflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ce27580136c80654f92f8ebc4b9fa1e59335183
--- /dev/null
+++ b/code_chatbot/agent_workflow.py
@@ -0,0 +1,135 @@
+
+from typing import TypedDict, Annotated, Sequence
+import operator
+from langchain_core.messages import BaseMessage
+from langchain_core.tools import tool
+from langgraph.graph import StateGraph, END
+from langgraph.prebuilt import ToolNode
+from code_chatbot.rate_limiter import get_rate_limiter
+
+# Define State
+class AgentState(TypedDict):
+ messages: Annotated[Sequence[BaseMessage], operator.add]
+
+def create_agent_graph(llm, retriever, repo_name: str = "Codebase", repo_dir: str = ".", provider: str = "gemini", code_analyzer=None):
+ """
+ Creates a LangGraph for the Code Chatbot.
+ Enables: Search -> Read File -> Reason -> Search -> Answer.
+ Uses adaptive rate limiting to maximize usage within free tier.
+ """
+
+ from pydantic import BaseModel, Field
+
+ class SearchInput(BaseModel):
+ query: str = Field(description="The query string to search for in the codebase.")
+
+ # 1. Wrap Retriever as a Tool
+ @tool("search_codebase", args_schema=SearchInput)
+ def search_codebase(query: str):
+ """
+ Search the codebase for code snippets relevant to the query.
+ Returns top 5 most relevant code sections with file paths.
+ Use this when you need to find specific functions, classes, or implementations.
+ You can call this multiple times with different queries to gather comprehensive information.
+ """
+ docs = retriever.invoke(query)
+ result = ""
+ # Increased to 5 results * 2000 chars = ~10000 chars (~2500 tokens) - much better context
+ for i, doc in enumerate(docs[:5]):
+ fp = doc.metadata.get('file_path', 'unknown')
+ # Show just the basename for a cleaner display
+ import os
+ display_path = os.path.basename(fp) if fp != 'unknown' else 'unknown'
+ content = doc.page_content[:2000] # Increased from 1000 to 2000
+ result += f"--- Result {i+1}: {display_path} ---\n{content}\n\n"
+
+ if not result:
+ return "No relevant code found. Try a different search query or use list_files to explore the codebase structure."
+
+ return result
+
+ # 2. Import File System Tools
+ from code_chatbot.tools import get_filesystem_tools, get_call_graph_tools
+
+ # 3. Combine Tools
+ fs_tools = get_filesystem_tools(repo_dir)
+ call_graph_tools = get_call_graph_tools(code_analyzer) if code_analyzer else []
+ tools = fs_tools + [search_codebase] + call_graph_tools
+
+ # 4. Bind to LLM
+ # Note: Not all LLMs support bind_tools cleanly, but Gemini/Groq(Llama3) do via LangChain
+ model_with_tools = llm.bind_tools(tools)
+
+ # 5. Define Nodes
+ # Get rate limiter for this provider
+ rate_limiter = get_rate_limiter(provider)
+
+ def agent(state):
+ messages = state["messages"]
+ import logging
+
+ logger = logging.getLogger(__name__)
+
+ # Smart adaptive delay - only waits when approaching rate limit
+ rate_limiter.wait_if_needed()
+
+ # Retry loop for 429 errors
+ # FAIL FAST: retry once after a 5s cooldown, then re-raise so the
+ # caller (rag.py) can trigger the Linear RAG fallback without stalling.
+ for i in range(2):
+ try:
+ response = model_with_tools.invoke(messages)
+ # Track usage for statistics (usage_metadata is a dict if present)
+ try:
+ usage = getattr(response, 'usage_metadata', None)
+ if usage:
+ rate_limiter.record_usage(
+ input_tokens=usage.get('input_tokens', 0),
+ output_tokens=usage.get('output_tokens', 0)
+ )
+ except Exception:
+ pass # Usage tracking is best-effort; never fail the call over stats
+
+ return {"messages": [response]}
+ except Exception as e:
+ # Catch both Gemini 429 and Groq rate-limit errors
+ if any(err in str(e) for err in ["429", "RESOURCE_EXHAUSTED", "rate_limit_exceeded"]):
+ if i == 1:
+ raise # Final attempt: fail fast so rag.py can fall back to Linear RAG
+ import time
+ logger.warning("⚠️ Rate limit hit. Cooling down for 5s...")
+ time.sleep(5)
+ else:
+ raise
+ return {"messages": []} # Unreachable: the loop always returns or raises
+
+ tool_node = ToolNode(tools)
+
+ # 6. Limits: the graph recursion limit is passed via the config at invoke time
+
+ # 7. Build Graph
+ workflow = StateGraph(AgentState)
+ workflow.add_node("agent", agent)
+ workflow.add_node("tools", tool_node)
+
+ workflow.set_entry_point("agent")
+
+ # Conditional Edge
+ def should_continue(state):
+ last_message = state["messages"][-1]
+
+ # If there is no tool call, then we finish
+ if not last_message.tool_calls:
+ return END
+
+ # Otherwise, route to the tools node
+ return "tools"
+
+ workflow.add_conditional_edges(
+ "agent",
+ should_continue,
+ )
+
+ workflow.add_edge("tools", "agent")
+
+ return workflow.compile()
diff --git a/code_chatbot/ast_analysis.py b/code_chatbot/ast_analysis.py
new file mode 100644
index 0000000000000000000000000000000000000000..2964d6408dcf86d50631a6432e7827e0fa27a9d0
--- /dev/null
+++ b/code_chatbot/ast_analysis.py
@@ -0,0 +1,516 @@
+"""
+Enhanced Code Analysis with AST + Call Graph + Control Flow
+
+This module provides comprehensive code analysis using:
+1. AST (Abstract Syntax Tree) - Code structure
+2. Call Graph - Function-to-function relationships
+3. Import Graph - Module dependencies
+4. Class Hierarchy - Inheritance relationships
+
+Uses tree-sitter for multi-language support.
+"""
+
+import logging
+import networkx as nx
+import os
+from typing import List, Dict, Optional, Set, Tuple
+from dataclasses import dataclass, field
+from tree_sitter import Language, Parser
+import tree_sitter_python
+import tree_sitter_javascript
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class FunctionInfo:
+ """Information about a function/method"""
+ name: str
+ file_path: str
+ start_line: int
+ end_line: int
+ is_method: bool = False
+ class_name: Optional[str] = None
+ calls: List[str] = field(default_factory=list)
+ parameters: List[str] = field(default_factory=list)
+
+ @property
+ def full_name(self) -> str:
+ if self.class_name:
+ return f"{self.class_name}.{self.name}"
+ return self.name
+
+ @property
+ def node_id(self) -> str:
+ return f"{self.file_path}::{self.full_name}"
+
+
+@dataclass
+class ClassInfo:
+ """Information about a class"""
+ name: str
+ file_path: str
+ start_line: int
+ end_line: int
+ bases: List[str] = field(default_factory=list) # Parent classes
+ methods: List[str] = field(default_factory=list)
+
+
+@dataclass
+class ImportInfo:
+ """Information about an import"""
+ module: str
+ names: List[str] = field(default_factory=list) # Specific names imported
+ is_from_import: bool = False
+
+
+class EnhancedCodeAnalyzer:
+ """
+ Enhanced code analyzer that builds:
+ - AST-based structure graph
+ - Function call graph
+ - Import dependency graph
+ - Class hierarchy graph
+ """
+
+ def __init__(self):
+ # Main knowledge graph
+ self.graph = nx.DiGraph()
+
+ # Specialized indices for faster lookups
+ self.functions: Dict[str, FunctionInfo] = {} # node_id -> FunctionInfo
+ self.classes: Dict[str, ClassInfo] = {} # node_id -> ClassInfo
+ self.imports: Dict[str, List[ImportInfo]] = {} # file_path -> imports
+ self.definitions: Dict[str, List[str]] = {} # name -> [node_ids]
+
+ # Track unresolved calls for later resolution
+ self.unresolved_calls: List[Tuple[str, str, int]] = [] # (caller_id, callee_name, line)
+
+ # Parsers
+ self.parsers = {}
+ self._init_parsers()
+
+ def _init_parsers(self):
+ """Initialize tree-sitter parsers for supported languages."""
+ try:
+ # Python
+ py_language = Language(tree_sitter_python.language())
+ py_parser = Parser(py_language)
+ self.parsers['python'] = py_parser
+ self.parsers['py'] = py_parser
+
+ # JavaScript
+ js_language = Language(tree_sitter_javascript.language())
+ js_parser = Parser(js_language)
+ self.parsers['javascript'] = js_parser
+ self.parsers['js'] = js_parser
+ self.parsers['jsx'] = js_parser
+
+ except Exception as e:
+ logger.error(f"Error initializing parsers: {e}")
+
+ def add_file(self, file_path: str, content: str):
+ """Parse a file and add it to the knowledge graph."""
+ ext = file_path.split('.')[-1].lower()
+ parser = self.parsers.get(ext)
+
+ if not parser:
+ return
+
+ try:
+ tree = parser.parse(bytes(content, "utf8"))
+ root_node = tree.root_node
+
+ # Add file node
+ self.graph.add_node(
+ file_path,
+ type="file",
+ name=os.path.basename(file_path),
+ language=ext
+ )
+
+ # Extract all symbols
+ self._extract_symbols(root_node, file_path, content)
+
+ except Exception as e:
+ logger.error(f"Failed to parse {file_path}: {e}")
+
+ def _extract_symbols(self, node, file_path: str, content: str,
+ current_class: Optional[str] = None,
+ current_function: Optional[str] = None):
+ """Recursively extract symbols from AST node."""
+
+ # ========== IMPORTS ==========
+ if node.type == "import_statement":
+ self._process_import(node, file_path, content)
+
+ elif node.type == "import_from_statement":
+ self._process_from_import(node, file_path, content)
+
+ # ========== CLASSES ==========
+ elif node.type == "class_definition":
+ class_info = self._process_class(node, file_path, content)
+ if class_info:
+ # Recurse into class body with class context
+ for child in node.children:
+ if child.type == "block":
+ self._extract_symbols(child, file_path, content,
+ current_class=class_info.name)
+ return # Don't recurse again below
+
+ # ========== FUNCTIONS/METHODS ==========
+ elif node.type == "function_definition":
+ func_info = self._process_function(node, file_path, content, current_class)
+ if func_info:
+ # Recurse into function body to find calls
+ for child in node.children:
+ if child.type == "block":
+ self._extract_symbols(child, file_path, content,
+ current_class=current_class,
+ current_function=func_info.node_id)
+ return # Don't recurse again below
+
+ # ========== FUNCTION CALLS ==========
+ elif node.type == "call":
+ self._process_call(node, file_path, content, current_function or file_path)
+
+ # Recurse into children
+ for child in node.children:
+ self._extract_symbols(child, file_path, content,
+ current_class, current_function)
+
+ def _process_import(self, node, file_path: str, content: str):
+ """Process import statement."""
+ # import module1, module2
+ for child in node.children:
+ if child.type == "dotted_name":
+ module_name = self._get_text(child, content)
+ import_info = ImportInfo(module=module_name)
+
+ if file_path not in self.imports:
+ self.imports[file_path] = []
+ self.imports[file_path].append(import_info)
+
+ # Add import edge
+ self.graph.add_edge(file_path, module_name, relation="imports")
+
+ def _process_from_import(self, node, file_path: str, content: str):
+ """Process from X import Y statement."""
+ module_name = None
+ names = []
+
+ for child in node.children:
+ if child.type == "dotted_name" and module_name is None:
+ module_name = self._get_text(child, content)
+ elif child.type == "import_from_list":
+ for name_node in child.children:
+ if name_node.type == "aliased_import":
+ name = self._get_text(name_node.children[0], content)
+ names.append(name)
+ elif name_node.type == "identifier":
+ names.append(self._get_text(name_node, content))
+
+ if module_name:
+ import_info = ImportInfo(module=module_name, names=names, is_from_import=True)
+ if file_path not in self.imports:
+ self.imports[file_path] = []
+ self.imports[file_path].append(import_info)
+
+ # Add import edge
+ self.graph.add_edge(file_path, module_name, relation="imports")
+
+ # Register imported names as potential definitions
+ for name in names:
+ if name not in self.definitions:
+ self.definitions[name] = []
+ self.definitions[name].append(f"{module_name}.{name}")
+
+ def _process_class(self, node, file_path: str, content: str) -> Optional[ClassInfo]:
+ """Process class definition."""
+ name_node = node.child_by_field_name("name")
+ if not name_node:
+ return None
+
+ class_name = self._get_text(name_node, content)
+ node_id = f"{file_path}::{class_name}"
+
+ # Get base classes
+ bases = []
+ for child in node.children:
+ if child.type == "argument_list":
+ for arg in child.children:
+ if arg.type == "identifier":
+ bases.append(self._get_text(arg, content))
+
+ class_info = ClassInfo(
+ name=class_name,
+ file_path=file_path,
+ start_line=node.start_point[0] + 1,
+ end_line=node.end_point[0] + 1,
+ bases=bases
+ )
+
+ self.classes[node_id] = class_info
+
+ # Add to graph
+ self.graph.add_node(
+ node_id,
+ type="class",
+ name=class_name,
+ start_line=class_info.start_line,
+ end_line=class_info.end_line
+ )
+
+ self.graph.add_edge(file_path, node_id, relation="defines")
+
+ # Add inheritance edges
+ for base in bases:
+ self.graph.add_edge(node_id, base, relation="inherits_from")
+
+ # Register definition
+ if class_name not in self.definitions:
+ self.definitions[class_name] = []
+ self.definitions[class_name].append(node_id)
+
+ return class_info
+
+ def _process_function(self, node, file_path: str, content: str,
+ current_class: Optional[str] = None) -> Optional[FunctionInfo]:
+ """Process function/method definition."""
+ name_node = node.child_by_field_name("name")
+ if not name_node:
+ return None
+
+ func_name = self._get_text(name_node, content)
+
+ # Get parameters
+ params = []
+ params_node = node.child_by_field_name("parameters")
+ if params_node:
+ for child in params_node.children:
+ if child.type == "identifier":
+ params.append(self._get_text(child, content))
+                elif child.type == "typed_parameter":
+                    # typed_parameter has no "name" field in the grammar; the
+                    # parameter name is its first identifier child.
+                    for sub in child.children:
+                        if sub.type == "identifier":
+                            params.append(self._get_text(sub, content))
+                            break
+
+ func_info = FunctionInfo(
+ name=func_name,
+ file_path=file_path,
+ start_line=node.start_point[0] + 1,
+ end_line=node.end_point[0] + 1,
+ is_method=current_class is not None,
+ class_name=current_class,
+ parameters=params
+ )
+
+ node_id = func_info.node_id
+ self.functions[node_id] = func_info
+
+ # Add to graph
+ self.graph.add_node(
+ node_id,
+ type="function" if not current_class else "method",
+ name=func_name,
+ full_name=func_info.full_name,
+ start_line=func_info.start_line,
+ end_line=func_info.end_line,
+ parameters=",".join(params)
+ )
+
+ # Link to parent (file or class)
+ if current_class:
+ class_id = f"{file_path}::{current_class}"
+ self.graph.add_edge(class_id, node_id, relation="has_method")
+ else:
+ self.graph.add_edge(file_path, node_id, relation="defines")
+
+ # Register definition
+ if func_name not in self.definitions:
+ self.definitions[func_name] = []
+ self.definitions[func_name].append(node_id)
+
+ return func_info
+
+ def _process_call(self, node, file_path: str, content: str, caller_id: str):
+ """Process function call."""
+ func_node = node.child_by_field_name("function")
+ if not func_node:
+ return
+
+ callee_name = self._get_text(func_node, content)
+ call_line = node.start_point[0] + 1
+
+ # Track call in function info
+ if caller_id in self.functions:
+ self.functions[caller_id].calls.append(callee_name)
+
+ # Store for later resolution
+ self.unresolved_calls.append((caller_id, callee_name, call_line))
+
+ def _get_text(self, node, content: str) -> str:
+ """Get text content of a node."""
+ return content[node.start_byte:node.end_byte]
+
+ def resolve_call_graph(self):
+ """Resolve all function calls to their definitions."""
+ resolved_count = 0
+
+ for caller_id, callee_name, line in self.unresolved_calls:
+ # Handle method calls like "self.method" or "obj.method"
+ simple_name = callee_name.split(".")[-1]
+
+ # Try to find definition
+ target_ids = []
+
+ # Check direct match
+ if callee_name in self.definitions:
+ target_ids.extend(self.definitions[callee_name])
+
+ # Check simple name (for methods)
+ if simple_name in self.definitions and simple_name != callee_name:
+ target_ids.extend(self.definitions[simple_name])
+
+ # Add call edges
+ for target_id in target_ids:
+ self.graph.add_edge(
+ caller_id,
+ target_id,
+ relation="calls",
+ line=line
+ )
+ resolved_count += 1
+
+ logger.info(f"Resolved {resolved_count} function calls in call graph")
+
+ def get_callers(self, function_name: str) -> List[str]:
+ """Find all functions that call the specified function."""
+ callers = []
+
+ # Find the function's node_id
+ target_ids = self.definitions.get(function_name, [])
+
+ for target_id in target_ids:
+ # Find incoming "calls" edges
+ for pred in self.graph.predecessors(target_id):
+ edge_data = self.graph.get_edge_data(pred, target_id)
+ if edge_data and edge_data.get("relation") == "calls":
+ callers.append(pred)
+
+ return callers
+
+ def get_callees(self, function_name: str) -> List[str]:
+ """Find all functions called by the given function."""
+ callees = []
+
+ # Find the function's node_id
+ caller_ids = self.definitions.get(function_name, [])
+
+ for caller_id in caller_ids:
+ # Find outgoing "calls" edges
+ for succ in self.graph.successors(caller_id):
+ edge_data = self.graph.get_edge_data(caller_id, succ)
+ if edge_data and edge_data.get("relation") == "calls":
+ callees.append(succ)
+
+ return callees
+
+    def get_call_chain(self, start_func: str, end_func: str, max_depth: int = 5) -> List[List[str]]:
+        """Find call paths from start_func to end_func (following "calls" edges only)."""
+        paths = []
+
+        # Restrict the search to call edges so paths cannot traverse
+        # structural edges such as "defines" or "has_method".
+        call_graph = nx.DiGraph(
+            (u, v) for u, v, d in self.graph.edges(data=True)
+            if d.get("relation") == "calls"
+        )
+
+        start_ids = self.definitions.get(start_func, [])
+        end_ids = self.definitions.get(end_func, [])
+
+        for start_id in start_ids:
+            for end_id in end_ids:
+                # all_simple_paths raises NodeNotFound (not NetworkXNoPath)
+                # for missing nodes; the membership check avoids both.
+                if start_id not in call_graph or end_id not in call_graph:
+                    continue
+                for path in nx.all_simple_paths(call_graph, start_id, end_id, cutoff=max_depth):
+                    paths.append(path)
+
+        return paths
+
+ def get_file_dependencies(self, file_path: str) -> Dict[str, List[str]]:
+ """Get all dependencies of a file (imports, calls to other files)."""
+ deps = {
+ "imports": [],
+ "calls_to": [],
+ "called_by": []
+ }
+
+ # Direct imports
+ deps["imports"] = [imp.module for imp in self.imports.get(file_path, [])]
+
+ # Functions in this file that call functions in other files
+ for func_id, func_info in self.functions.items():
+ if func_info.file_path == file_path:
+ for callee in self.get_callees(func_info.name):
+ callee_file = callee.split("::")[0]
+ if callee_file != file_path and callee_file not in deps["calls_to"]:
+ deps["calls_to"].append(callee_file)
+
+ # Functions in other files that call functions in this file
+ for func_id, func_info in self.functions.items():
+ if func_info.file_path == file_path:
+ for caller in self.get_callers(func_info.name):
+ caller_file = caller.split("::")[0]
+ if caller_file != file_path and caller_file not in deps["called_by"]:
+ deps["called_by"].append(caller_file)
+
+ return deps
+
+ def get_related_nodes(self, node_id: str, depth: int = 2) -> List[str]:
+ """Get nodes related to the given node via graph traversal."""
+ if node_id not in self.graph:
+ # Try to find by name
+ if node_id in self.definitions:
+ node_ids = self.definitions[node_id]
+ all_related = []
+ for nid in node_ids:
+ all_related.extend(list(nx.bfs_tree(self.graph, nid, depth_limit=depth)))
+ return list(set(all_related))
+ return []
+
+ return list(nx.bfs_tree(self.graph, node_id, depth_limit=depth))
+
+ def get_statistics(self) -> Dict:
+ """Get analysis statistics."""
+ return {
+ "total_nodes": self.graph.number_of_nodes(),
+ "total_edges": self.graph.number_of_edges(),
+ "files": len([n for n, d in self.graph.nodes(data=True) if d.get("type") == "file"]),
+ "classes": len(self.classes),
+ "functions": len([f for f in self.functions.values() if not f.is_method]),
+ "methods": len([f for f in self.functions.values() if f.is_method]),
+ "imports": sum(len(imps) for imps in self.imports.values()),
+ "call_edges": len([1 for _, _, d in self.graph.edges(data=True) if d.get("relation") == "calls"])
+ }
+
+ def save_graph(self, path: str):
+ """Save the graph to a GraphML file."""
+ # Resolve call graph first
+ self.resolve_call_graph()
+
+ # Log statistics
+ stats = self.get_statistics()
+ logger.info(f"Graph Statistics: {stats}")
+
+ nx.write_graphml(self.graph, path)
+ logger.info(f"Graph saved to {path}")
+
+
+# Backward compatibility alias
+class ASTGraphBuilder(EnhancedCodeAnalyzer):
+ """Alias for backward compatibility with existing code."""
+ pass
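The caller/callee helpers above boil down to filtering typed edges in the graph. A minimal stdlib-only sketch of the same idea (hypothetical node ids; the real class stores richer metadata and uses networkx):

```python
# Toy call graph as (caller, callee, relation) triples, mirroring the
# "calls" edges the analyzer adds. Node ids follow the "file::function"
# convention used above; the file and function names are hypothetical.
EDGES = [
    ("app.py::main", "utils.py::load_config", "calls"),
    ("app.py::main", "utils.py::connect", "calls"),
    ("worker.py::run", "utils.py::connect", "calls"),
    ("app.py", "app.py::main", "defines"),  # structural edge, ignored below
]

def get_callers(target: str) -> list[str]:
    """All nodes with an outgoing 'calls' edge into target."""
    return [src for src, dst, rel in EDGES if dst == target and rel == "calls"]

def get_callees(source: str) -> list[str]:
    """All nodes the source node calls."""
    return [dst for src, dst, rel in EDGES if src == source and rel == "calls"]

print(get_callers("utils.py::connect"))  # both call sites
print(get_callees("app.py::main"))
```

Filtering on the `relation` attribute is what keeps structural edges ("defines", "has_method") out of call-graph answers.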
diff --git a/code_chatbot/chunker.py b/code_chatbot/chunker.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ee80a60114d3df6bff3e13a76e9dac135988ef6
--- /dev/null
+++ b/code_chatbot/chunker.py
@@ -0,0 +1,251 @@
+"""Enhanced chunker with proper token counting and merging strategies, inspired by Sage."""
+
+import logging
+import os
+from typing import List, Dict, Any, Optional
+from dataclasses import dataclass
+from functools import cached_property
+
+import pygments.lexers
+import pygments.util
+import tiktoken
+from langchain_core.documents import Document
+from tree_sitter import Language, Parser, Node
+import tree_sitter_python
+import tree_sitter_javascript
+
+logger = logging.getLogger(__name__)
+tokenizer = tiktoken.get_encoding("cl100k_base")
+
+
+@dataclass
+class FileChunk:
+ """Represents a chunk of code with byte positions."""
+ file_content: str
+ file_metadata: Dict
+ start_byte: int
+ end_byte: int
+
+ @cached_property
+ def filename(self):
+ if "file_path" not in self.file_metadata:
+ raise ValueError("file_metadata must contain a 'file_path' key.")
+ return self.file_metadata["file_path"]
+
+ @cached_property
+ def content(self) -> str:
+ """The text content to be embedded. Includes filename for context."""
+ return self.filename + "\n\n" + self.file_content[self.start_byte : self.end_byte]
+
+ @cached_property
+ def num_tokens(self):
+ """Number of tokens in this chunk."""
+ return len(tokenizer.encode(self.content, disallowed_special=()))
+
+ def to_document(self) -> Document:
+ """Convert to LangChain Document."""
+ chunk_type = self.file_metadata.get("chunk_type", "code")
+ name = self.file_metadata.get("name", None)
+
+        return Document(
+            page_content=self.content,
+            metadata={
+                # Drop the internal full-file copy so it is not persisted
+                # into vector store metadata.
+                **{k: v for k, v in self.file_metadata.items() if k != "_full_content"},
+                "id": f"{self.filename}_{self.start_byte}_{self.end_byte}",
+                "start_byte": self.start_byte,
+                "end_byte": self.end_byte,
+                "length": self.end_byte - self.start_byte,
+                "chunk_type": chunk_type,
+                "name": name,
+            }
+        )
+
+
+class StructuralChunker:
+ """
+ Chunks code files based on their AST structure (Functions, Classes) using Tree-sitter.
+ Uses proper token counting with tiktoken and implements merging strategies to avoid
+ pathologically small chunks.
+ """
+ def __init__(self, max_tokens: int = 800):
+ self.max_tokens = max_tokens
+ self.parsers = {}
+ self._init_parsers()
+
+ def _init_parsers(self):
+ try:
+ self.parsers['py'] = Parser(Language(tree_sitter_python.language()))
+ self.parsers['python'] = self.parsers['py']
+ js_parser = Parser(Language(tree_sitter_javascript.language()))
+ self.parsers['js'] = js_parser
+ self.parsers['javascript'] = js_parser
+ self.parsers['jsx'] = js_parser
+ self.parsers['ts'] = js_parser
+ self.parsers['tsx'] = js_parser
+ except Exception as e:
+ logger.error(f"Error initializing parsers in Chunker: {e}")
+
+ @staticmethod
+ def _get_language_from_filename(filename: str) -> Optional[str]:
+ """Returns a canonical name for the language based on file extension."""
+ extension = os.path.splitext(filename)[1]
+ if extension == ".tsx":
+ return "tsx"
+
+ try:
+ lexer = pygments.lexers.get_lexer_for_filename(filename)
+ return lexer.name.lower()
+ except pygments.util.ClassNotFound:
+ return None
+
+ @staticmethod
+ def is_code_file(filename: str) -> bool:
+ """Checks whether the file can be parsed as code."""
+ language = StructuralChunker._get_language_from_filename(filename)
+        return language is not None and language not in ["text only", "none"]
+
+ def chunk(self, content: str, file_path: str) -> List[Document]:
+ """Main chunking entry point."""
+ ext = file_path.split('.')[-1].lower()
+ parser = self.parsers.get(ext)
+
+ if "\0" in content:
+ logger.warning(f"Binary content detected in {file_path}, skipping chunking")
+ return []
+
+ if not parser:
+ logger.warning(f"No parser found for extension: {ext}, treating as text file")
+ # Fallback to simple text chunking for non-code files
+ return self._chunk_text_file(content, file_path)
+
+ try:
+ tree = parser.parse(bytes(content, "utf8"))
+
+ if not tree.root_node.children or tree.root_node.children[0].type == "ERROR":
+ logger.warning(f"Failed to parse code in {file_path}, falling back to text chunking")
+ return self._chunk_text_file(content, file_path)
+
+ file_metadata = {"file_path": file_path, "chunk_type": "code", "_full_content": content}
+ file_chunks = self._chunk_node(tree.root_node, content, file_metadata)
+
+ # Convert FileChunk objects to Documents
+ return [chunk.to_document() for chunk in file_chunks]
+
+ except Exception as e:
+ logger.error(f"Failed to chunk {file_path}: {e}, falling back to text chunking")
+ return self._chunk_text_file(content, file_path)
+
+ def _chunk_text_file(self, content: str, file_path: str) -> List[Document]:
+ """Fallback chunking for text files."""
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
+ splitter = RecursiveCharacterTextSplitter(
+ chunk_size=self.max_tokens * 4, # Approximate char count
+ chunk_overlap=200,
+ separators=["\n\n", "\n", " ", ""]
+ )
+ texts = splitter.split_text(content)
+ return [
+ Document(
+ page_content=f"{file_path}\n\n{text}",
+ metadata={"file_path": file_path, "chunk_type": "text"}
+ )
+ for text in texts
+ ]
+
+ def _chunk_node(self, node: Node, file_content: str, file_metadata: Dict) -> List[FileChunk]:
+ """
+ Recursively splits a node into chunks.
+ If a node is small enough, returns it as a single chunk.
+ If too large, recursively chunks its children and merges neighboring chunks when possible.
+ """
+ node_chunk = FileChunk(file_content, file_metadata, node.start_byte, node.end_byte)
+
+ # If chunk is small enough and not a module/program node, return it
+ if node_chunk.num_tokens <= self.max_tokens and node.type not in ["module", "program"]:
+ # Add metadata about the node type and name
+ chunk_metadata = {**file_metadata}
+ chunk_metadata["chunk_type"] = node.type
+ name = self._get_node_name(node, file_content)
+ if name:
+ chunk_metadata["name"] = name
+ node_chunk.file_metadata = chunk_metadata
+ return [node_chunk]
+
+ # If leaf node is too large, split it as text
+ if not node.children:
+ return self._chunk_large_text(
+ file_content[node.start_byte : node.end_byte],
+ node.start_byte,
+ file_metadata
+ )
+
+ # Recursively chunk children
+ chunks = []
+ for child in node.children:
+ chunks.extend(self._chunk_node(child, file_content, file_metadata))
+
+ # Merge neighboring chunks if their combined size doesn't exceed max_tokens
+ merged_chunks = []
+ for chunk in chunks:
+ if not merged_chunks:
+ merged_chunks.append(chunk)
+ elif merged_chunks[-1].num_tokens + chunk.num_tokens < self.max_tokens - 50:
+ # Try merging
+ merged = FileChunk(
+ file_content,
+ file_metadata,
+ merged_chunks[-1].start_byte,
+ chunk.end_byte,
+ )
+ if merged.num_tokens <= self.max_tokens:
+ merged_chunks[-1] = merged
+ else:
+ merged_chunks.append(chunk)
+ else:
+ merged_chunks.append(chunk)
+
+ # Verify all chunks are within token limit
+ for chunk in merged_chunks:
+ if chunk.num_tokens > self.max_tokens:
+ logger.warning(
+ f"Chunk size {chunk.num_tokens} exceeds max_tokens {self.max_tokens} "
+ f"for {chunk.filename} at bytes {chunk.start_byte}-{chunk.end_byte}"
+ )
+
+ return merged_chunks
+
+ def _chunk_large_text(self, text: str, start_offset: int, file_metadata: Dict) -> List[FileChunk]:
+ """Splits large text (e.g., long comments or strings) into smaller chunks."""
+ # Need full file content for FileChunk to work properly
+ file_content = file_metadata.get("_full_content", "")
+ if not file_content:
+ logger.warning("Cannot chunk large text without full file content")
+ return []
+
+        from langchain_text_splitters import RecursiveCharacterTextSplitter
+        # No overlap: the offset bookkeeping below assumes consecutive,
+        # non-overlapping chunks (overlap would make offsets drift).
+        splitter = RecursiveCharacterTextSplitter(
+            chunk_size=self.max_tokens * 4,
+            chunk_overlap=0
+        )
+ texts = splitter.split_text(text)
+
+ chunks = []
+ current_offset = start_offset
+ for text_chunk in texts:
+ end_offset = current_offset + len(text_chunk)
+ chunk = FileChunk(
+ file_content,
+ {**file_metadata, "chunk_type": "large_text"},
+ current_offset,
+ end_offset
+ )
+ chunks.append(chunk)
+ current_offset = end_offset
+
+ return chunks
+
+ def _get_node_name(self, node: Node, content: str) -> Optional[str]:
+ """Extracts the name of a function or class node."""
+ name_node = node.child_by_field_name("name")
+ if name_node:
+ return content[name_node.start_byte:name_node.end_byte]
+ return None
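The merge pass in `_chunk_node` is what prevents pathologically small chunks: adjacent sibling chunks are greedily coalesced while the combined size stays under the budget. A stripped-down sketch of that loop, using word counts as a stand-in for tiktoken tokens:

```python
MAX_TOKENS = 10  # toy budget; the chunker above defaults to 800

def num_tokens(text: str) -> int:
    # Stand-in for tiktoken; the real chunker encodes with cl100k_base.
    return len(text.split())

def merge_chunks(chunks: list[str]) -> list[str]:
    """Greedily merge neighbors while the merged chunk fits the budget."""
    merged: list[str] = []
    for chunk in chunks:
        if merged and num_tokens(merged[-1]) + num_tokens(chunk) <= MAX_TOKENS:
            merged[-1] = merged[-1] + "\n" + chunk
        else:
            merged.append(chunk)
    return merged

parts = ["def a(): pass", "def b(): pass", "x " * 9, "def c(): pass"]
print(merge_chunks(parts))  # the two small defs merge; the big chunk stands alone
```

The real implementation merges byte ranges rather than strings, so a merged `FileChunk` still maps back to exact positions in the source file.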
diff --git a/code_chatbot/cli.py b/code_chatbot/cli.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3720f161b48e8e5dda184bd9f15ea33ebc9e74e
--- /dev/null
+++ b/code_chatbot/cli.py
@@ -0,0 +1,298 @@
+#!/usr/bin/env python3
+"""
+🕷️ Code Crawler CLI
+Command-line interface for the Code Crawler engine.
+"""
+
+import argparse
+import os
+import sys
+import logging
+import shutil
+import json
+from dotenv import load_dotenv
+
+# Rich Imports
+from rich.console import Console
+from rich.markdown import Markdown
+from rich.panel import Panel
+from rich.prompt import Prompt
+from rich.progress import Progress, SpinnerColumn, TextColumn
+
+# Local Imports
+from .indexer import Indexer
+from .rag import ChatEngine
+from .ast_analysis import ASTGraphBuilder
+from .graph_rag import GraphEnhancedRetriever
+from .universal_ingestor import process_source
+from .agent_workflow import create_agent_graph
+
+# Configure Console
+console = Console()
+logging.basicConfig(level=logging.ERROR)
+# Suppress noisy libraries
+logging.getLogger("httpx").setLevel(logging.WARNING)
+logging.getLogger("httpcore").setLevel(logging.WARNING)
+logging.getLogger("chromadb").setLevel(logging.ERROR)
+logging.getLogger("google_genai").setLevel(logging.ERROR)
+logging.getLogger("google.genai").setLevel(logging.ERROR)
+logging.getLogger("code_chatbot.chunker").setLevel(logging.ERROR)
+
+logger = logging.getLogger("CodeCrawlerCLI")
+logger.setLevel(logging.INFO)
+
+BANNER = """
+[bold cyan] 🕷️ Code Crawler CLI 🕷️[/bold cyan]
+[dim] Index. Chat. Understand.[/dim]
+"""
+
+def setup_env():
+ load_dotenv()
+
+def print_banner():
+ console.print(Panel(BANNER, subtitle="v2.0", border_style="cyan"))
+
+def handle_index(args):
+ """
+ Handles the indexing command.
+ """
+ console.print(f"[bold blue][INFO][/bold blue] Starting indexing for source: [green]{args.source}[/green]")
+
+ # 1. Setup Environment
+ if args.provider == "gemini":
+ api_key = os.getenv("GOOGLE_API_KEY")
+ if not api_key:
+ console.print("[bold red][ERROR][/bold red] GOOGLE_API_KEY not found in .env")
+ sys.exit(1)
+ embedding_provider = "gemini"
+ embedding_api_key = api_key
+ elif args.provider == "groq":
+ api_key = os.getenv("GROQ_API_KEY")
+ embedding_api_key = os.getenv("GOOGLE_API_KEY")
+ if not api_key:
+ console.print("[bold red][ERROR][/bold red] GROQ_API_KEY not found in .env")
+ sys.exit(1)
+ if not embedding_api_key:
+ console.print("[bold red][ERROR][/bold red] GOOGLE_API_KEY (for embeddings) not found in .env")
+ sys.exit(1)
+ embedding_provider = "gemini"
+ else:
+ console.print(f"[bold red]Unknown provider:[/bold red] {args.provider}")
+ sys.exit(1)
+
+ try:
+ # 2. Extract & Ingest
+ extract_to = "data/extracted"
+ # Optional: Clean previous data
+ if args.clean and os.path.exists(extract_to):
+ console.print("[bold yellow][WARN][/bold yellow] Cleaning previous data...")
+ shutil.rmtree(extract_to)
+
+ with Progress(SpinnerColumn(), TextColumn("[progress.description]{task.description}"), console=console) as progress:
+ task = progress.add_task("Processing source...", total=None)
+ documents, local_path = process_source(args.source, extract_to)
+ progress.update(task, completed=True, description="[bold green]Source Processed[/bold green]")
+
+ console.print(f"[bold green][SUCCESS][/bold green] Ingested {len(documents)} documents.")
+
+ # Save metadata for Chat to find the path
+ os.makedirs("data", exist_ok=True)
+ with open("data/cli_meta.json", "w") as f:
+ json.dump({"repo_path": local_path}, f)
+
+ # 3. AST Analysis
+ with Progress(SpinnerColumn(), TextColumn("[progress.description]{task.description}"), console=console) as progress:
+ task = progress.add_task("Building AST Knowledge Graph...", total=None)
+ ast_builder = ASTGraphBuilder()
+ for doc in documents:
+ # doc.metadata['file_path'] is absolute
+ ast_builder.add_file(doc.metadata['file_path'], doc.page_content)
+
+ # Web sources might not create the directory
+ os.makedirs(local_path, exist_ok=True)
+ graph_path = os.path.join(local_path, "ast_graph.graphml")
+ ast_builder.save_graph(graph_path)
+ progress.update(task, completed=True, description="[bold green]AST Graph Built[/bold green]")
+
+ console.print(f"[bold green][SUCCESS][/bold green] AST Graph ready ({ast_builder.graph.number_of_nodes()} nodes).")
+
+ # 4. Vector Indexing
+ with Progress(SpinnerColumn(), TextColumn("[progress.description]{task.description}"), console=console) as progress:
+ task = progress.add_task(f"Indexing into {args.vector_db}...", total=None)
+ indexer = Indexer(
+ provider=embedding_provider,
+ api_key=embedding_api_key
+ )
+ # Clear old data if requested
+ if args.clean:
+ indexer.clear_collection()
+
+ indexer.index_documents(documents, vector_db_type=args.vector_db)
+ progress.update(task, completed=True, description=f"[bold green]Indexed into {args.vector_db}[/bold green]")
+
+ console.print(f"[bold green][SUCCESS][/bold green] Indexing Complete! You can now run `code-crawler chat`.")
+
+ except Exception as e:
+ console.print(f"[bold red][ERROR][/bold red] Indexing failed: {e}")
+
+def handle_chat(args):
+ """
+ Handles the chat command.
+ """
+ console.print(f"[bold blue][INFO][/bold blue] Initializing Chat Engine ({args.provider})...")
+
+ # Setup Env & Keys
+ if args.provider == "gemini":
+ api_key = os.getenv("GOOGLE_API_KEY")
+ embedding_api_key = api_key
+ embedding_provider = "gemini"
+ model_name = "gemini-2.5-flash"
+ llm_provider_lib = "google_genai"
+ elif args.provider == "groq":
+ api_key = os.getenv("GROQ_API_KEY")
+ embedding_api_key = os.getenv("GOOGLE_API_KEY")
+ embedding_provider = "gemini"
+ model_name = "llama-3.3-70b-versatile"
+ llm_provider_lib = "groq"
+
+    if not api_key or not embedding_api_key:
+        console.print("[bold red][ERROR][/bold red] API keys missing. Check .env")
+        sys.exit(1)
+
+ try:
+ # Load Resources
+ meta_file = "data/cli_meta.json"
+ if os.path.exists(meta_file):
+ with open(meta_file, "r") as f:
+ meta = json.load(f)
+ local_path = meta.get("repo_path")
+ else:
+ # Fallback Heuristic
+ extract_root = "data/extracted"
+ if not os.path.exists(extract_root):
+ console.print("[bold red][ERROR][/bold red] No index info found. Run 'code-crawler index' first.")
+ sys.exit(1)
+
+ subdirs = [f.path for f in os.scandir(extract_root) if f.is_dir()]
+ if not subdirs:
+ local_path = extract_root
+ else:
+ subdirs.sort(key=lambda x: os.path.getmtime(x), reverse=True)
+ local_path = subdirs[0]
+
+ if not local_path or not os.path.exists(local_path):
+ console.print(f"[bold red][ERROR][/bold red] Codebase path not found: {local_path}")
+ sys.exit(1)
+
+ console.print(f"[dim]Using codebase at: {local_path}[/dim]")
+
+ # Initialize Components
+ with Progress(SpinnerColumn(), TextColumn("[progress.description]{task.description}"), console=console) as progress:
+ task = progress.add_task("Loading resources...", total=None)
+
+ indexer = Indexer(provider=embedding_provider, api_key=embedding_api_key)
+ base_retriever = indexer.get_retriever(vector_db_type=args.vector_db)
+
+ graph_retriever = GraphEnhancedRetriever(
+ base_retriever=base_retriever,
+ repo_dir=local_path
+ )
+
+ repo_files = []
+ for root, _, files in os.walk(local_path):
+ for file in files:
+ repo_files.append(os.path.join(root, file))
+
+ progress.update(task, completed=True, description="[bold green]Resources Loaded[/bold green]")
+
+ # Initialize ChatEngine
+ if args.agent:
+ console.print("[bold purple]🤖 Agent Mode Enabled[/bold purple]")
+
+ chat_engine = ChatEngine(
+ retriever=graph_retriever,
+ provider=args.provider,
+ model_name=model_name,
+ api_key=api_key,
+ repo_files=repo_files,
+ repo_name=os.path.basename(local_path),
+ use_agent=args.agent,
+ repo_dir=local_path
+ )
+
+ console.print("\n[bold green]Ready![/bold green] chat initialized. Type 'exit' to quit.\n")
+
+ while True:
+ try:
+ query = Prompt.ask("[bold cyan]User[/bold cyan]")
+ if query.strip().lower() in ['exit', 'quit', ':q']:
+ break
+
+ if not query.strip():
+ continue
+
+ console.print("[dim]🕷️ Thinking...[/dim]")
+
+ # Unified Chat Call (Handles Agent & Standard + Fallback)
+ response = chat_engine.chat(query)
+
+ if isinstance(response, tuple):
+ answer, sources = response
+ else:
+ answer = response
+ sources = []
+
+ # Render Response
+ console.print(Panel(Markdown(answer), title="Spider", border_style="magenta", expand=False))
+
+ if sources:
+ console.print("[dim]Sources:[/dim]")
+ seen = set()
+ for s in sources:
+ fp = s.get('file_path', 'unknown')
+ if fp not in seen:
+ console.print(f" - [underline]{os.path.basename(fp)}[/underline]")
+ seen.add(fp)
+ console.print("")
+
+ except KeyboardInterrupt:
+ break
+ except Exception as e:
+ console.print(f"[bold red][ERROR][/bold red] {e}")
+
+ except Exception as e:
+ console.print(f"[bold red][ERROR][/bold red] Chat failed to start: {e}")
+
+def main():
+ setup_env()
+ print_banner()
+
+ parser = argparse.ArgumentParser(description="Code Crawler CLI")
+ subparsers = parser.add_subparsers(dest="command", required=True)
+
+ # Index Command
+ index_parser = subparsers.add_parser("index", help="Index a codebase (ZIP, URL, or Path)")
+ index_parser.add_argument("--source", "-s", required=True, help="Path to ZIP, Folder, or GitHub URL")
+ index_parser.add_argument("--provider", "-p", default="gemini", choices=["gemini", "groq"], help="LLM Provider")
+ index_parser.add_argument("--vector-db", "-v", default="chroma", choices=["chroma", "faiss"], help="Vector Database")
+ index_parser.add_argument("--clean", action="store_true", help="Clean previous index before running")
+
+ # Chat Command
+ chat_parser = subparsers.add_parser("chat", help="Chat with the indexed codebase")
+ chat_parser.add_argument("--provider", "-p", default="gemini", choices=["gemini", "groq"], help="LLM Provider")
+ chat_parser.add_argument("--vector-db", "-v", default="chroma", choices=["chroma", "faiss"], help="Vector Database type used during index")
+ chat_parser.add_argument("--agent", "-a", action="store_true", help="Enable Agentic Reasoning (LangGraph)")
+
+ args = parser.parse_args()
+
+ if args.command == "index":
+ handle_index(args)
+ elif args.command == "chat":
+ handle_chat(args)
+
+if __name__ == "__main__":
+ main()
diff --git a/code_chatbot/code_symbols.py b/code_chatbot/code_symbols.py
new file mode 100644
index 0000000000000000000000000000000000000000..45b53ac6b474aab98bc0404e4c40de70ee2f1cee
--- /dev/null
+++ b/code_chatbot/code_symbols.py
@@ -0,0 +1,88 @@
+"""Utilities to extract code symbols (class and method names) from code files."""
+
+import logging
+from typing import List, Tuple, Optional
+from tree_sitter import Node
+
+from code_chatbot.chunker import StructuralChunker
+
+logger = logging.getLogger(__name__)
+
+
+def _extract_classes_and_methods(node: Node, acc: List[Tuple[Optional[str], Optional[str]]], parent_class: Optional[str] = None, content: str = ""):
+ """Extracts classes and methods from a tree-sitter node and places them in the `acc` accumulator.
+
+ Args:
+ node: The tree-sitter node to traverse
+ acc: Accumulator list to store (class_name, method_name) tuples
+ parent_class: Name of the parent class (if any)
+ content: The file content as string (for extracting names)
+ """
+ if node.type in ["class_definition", "class_declaration"]:
+ class_name_node = node.child_by_field_name("name")
+ if class_name_node:
+ class_name = content[class_name_node.start_byte:class_name_node.end_byte]
+ if class_name:
+ acc.append((class_name, None))
+ # Recursively process children with this class as parent
+ for child in node.children:
+ _extract_classes_and_methods(child, acc, class_name, content)
+ return
+ elif node.type in ["function_definition", "method_definition"]:
+ function_name_node = node.child_by_field_name("name")
+ if function_name_node:
+ method_name = content[function_name_node.start_byte:function_name_node.end_byte]
+ if method_name:
+ acc.append((parent_class, method_name))
+ # Don't go deeper into method bodies (we're not extracting nested functions)
+ return
+ else:
+ # Recursively process children
+ for child in node.children:
+ _extract_classes_and_methods(child, acc, parent_class, content)
+
+
+def get_code_symbols(file_path: str, content: str) -> List[Tuple[Optional[str], Optional[str]]]:
+ """Extracts code symbols from a file.
+
+ Code symbols are tuples of the form (class_name, method_name).
+ For classes, method_name is None.
+ For methods that do not belong to a class, class_name is None.
+
+ Args:
+ file_path: Path to the file
+ content: Content of the file as a string
+
+ Returns:
+ List of (class_name, method_name) tuples
+ """
+ if not StructuralChunker.is_code_file(file_path):
+ return []
+
+ if not content:
+ return []
+
+ logger.debug(f"Extracting code symbols from {file_path}")
+
+ # Try to parse the file using the chunker's parsing logic
+ try:
+ ext = file_path.split('.')[-1].lower()
+ chunker = StructuralChunker()
+
+ if ext not in chunker.parsers:
+ return []
+
+ parser = chunker.parsers[ext]
+ tree = parser.parse(bytes(content, "utf8"))
+
+ if not tree or not tree.root_node.children:
+ return []
+
+ classes_and_methods = []
+ _extract_classes_and_methods(tree.root_node, classes_and_methods, None, content)
+ return classes_and_methods
+
+ except Exception as e:
+ logger.warning(f"Failed to extract code symbols from {file_path}: {e}")
+ return []
+
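The accumulator yields `(class_name, method_name)` tuples: `(C, None)` for a class, `(C, m)` for a method, and `(None, f)` for a module-level function. A stdlib-only sketch of the same traversal over a toy node type (the real code walks tree-sitter nodes and extracts names via byte offsets):

```python
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class Node:
    type: str
    name: Optional[str] = None
    children: list = field(default_factory=list)

def extract(node, acc, parent=None):
    if node.type == "class_definition" and node.name:
        acc.append((node.name, None))
        for child in node.children:
            extract(child, acc, node.name)  # methods inherit this class as parent
    elif node.type == "function_definition" and node.name:
        acc.append((parent, node.name))  # stop: nested functions are skipped
    else:
        for child in node.children:
            extract(child, acc, parent)

tree = Node("module", children=[
    Node("class_definition", "Indexer", [
        Node("function_definition", "index"),
        Node("function_definition", "clear"),
    ]),
    Node("function_definition", "main"),
])
symbols = []
extract(tree, symbols)
print(symbols)
```

Note the early return after a function: the walker deliberately does not descend into method bodies, matching the "no nested functions" rule above.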
diff --git a/code_chatbot/graph_rag.py b/code_chatbot/graph_rag.py
new file mode 100644
index 0000000000000000000000000000000000000000..35d34892367cc4dff40ffffc4fb5637abf28e57a
--- /dev/null
+++ b/code_chatbot/graph_rag.py
@@ -0,0 +1,111 @@
+import os
+import networkx as nx
+import logging
+from typing import List, Optional, Any
+from langchain_core.retrievers import BaseRetriever
+from langchain_core.documents import Document
+
+logger = logging.getLogger(__name__)
+
+class GraphEnhancedRetriever(BaseRetriever):
+ """Wraps a base retriever and augments results using an AST knowledge graph."""
+
+ base_retriever: BaseRetriever
+ graph: Optional[Any] = None
+ repo_dir: str
+
+ def __init__(self, base_retriever: BaseRetriever, repo_dir: str, **kwargs):
+ # Initialize Pydantic fields
+ super().__init__(base_retriever=base_retriever, repo_dir=repo_dir, **kwargs)
+ self.graph = self._load_graph()
+
+ def _load_graph(self):
+ graph_path = os.path.join(self.repo_dir, "ast_graph.graphml")
+ if os.path.exists(graph_path):
+ try:
+ logger.info(f"Loading AST Graph from {graph_path}")
+ return nx.read_graphml(graph_path)
+ except Exception as e:
+ logger.error(f"Failed to load AST graph: {e}")
+ else:
+ logger.warning(f"No AST graph found at {graph_path}")
+ return None
+
+ def _get_relevant_documents(self, query: str, *, run_manager=None) -> List[Document]:
+ # 1. Standard Retrieval
+ logger.info(f"GraphEnhancedRetriever: Querying base retriever with: '{query}'")
+ docs = self.base_retriever.invoke(query)
+ logger.info(f"GraphEnhancedRetriever: Base retriever returned {len(docs)} documents")
+
+ if not self.graph:
+ logger.warning("No AST graph available for enhancement")
+ return docs
+
+ # 2. Graph Expansion
+ augmented_docs = list(docs)
+ seen_files = {d.metadata.get("file_path") for d in docs}
+
+        # Look for files related to the retrieved documents (via imports,
+        # calls, inheritance) that the vector search itself did not surface.
+        for doc in docs:
+            file_path = doc.metadata.get("file_path")
+            if not file_path:
+                continue
+
+            # Graph node ids are the paths that were passed to add_file, so a
+            # direct lookup works as long as indexing and retrieval agree on
+            # path style (relative vs absolute). No fuzzy matching is done.
+            target_node = file_path if file_path in self.graph else None
+
+            if target_node:
+ neighbors = list(self.graph.neighbors(target_node))
+ for neighbor in neighbors:
+ # Neighbor could be a file or a symbol (file::symbol)
+ if "::" in neighbor:
+ neighbor_file = neighbor.split("::")[0]
+ else:
+ neighbor_file = neighbor
+
+ # Skip if we've already seen this file
+ if neighbor_file in seen_files:
+ continue
+
+ # Check if file exists (handle both relative and absolute paths)
+ if os.path.exists(neighbor_file):
+ try:
+ # Limit expansion to small files to avoid context overflow
+ if os.path.getsize(neighbor_file) < 20000: # 20KB limit
+ with open(neighbor_file, "r", errors='ignore') as f:
+ content = f.read()
+
+ # Get relationship type from edge
+ edge_data = self.graph.get_edge_data(target_node, neighbor, {})
+ relation = edge_data.get("relation", "related") if edge_data else "related"
+
+ new_doc = Document(
+ page_content=f"--- Graph Context ({relation} from {os.path.basename(file_path)}) ---\n{content}",
+ metadata={
+ "file_path": neighbor_file,
+ "source": "ast_graph",
+ "relation": relation,
+ "related_to": file_path
+ }
+ )
+ augmented_docs.append(new_doc)
+ seen_files.add(neighbor_file)
+ logger.debug(f"Added graph-related file: {neighbor_file} (relation: {relation})")
+ except Exception as e:
+ logger.warning(f"Failed to add graph-related file {neighbor_file}: {e}")
+
+ return augmented_docs
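The one-hop expansion in `_get_relevant_documents` above can be exercised in isolation. Below is a stripped-down sketch with a plain adjacency dict standing in for the networkx GraphML graph; the function and variable names are illustrative, not part of the module.

```python
# Minimal sketch of the graph-expansion step: start from the files the vector
# search returned, add each unseen one-hop neighbor from the AST graph.
def expand_with_neighbors(hit_files, graph):
    """Return hit_files plus unseen files reachable in one hop."""
    seen = set(hit_files)
    expanded = list(hit_files)
    for file_path in hit_files:
        for neighbor in graph.get(file_path, []):
            # Symbol nodes are keyed as "file::symbol"; keep only the file part.
            neighbor_file = neighbor.split("::")[0]
            if neighbor_file not in seen:
                expanded.append(neighbor_file)
                seen.add(neighbor_file)
    return expanded

graph = {
    "app.py": ["utils.py::helper", "models.py"],
    "models.py": ["db.py"],
}
print(expand_with_neighbors(["app.py"], graph))
# → ['app.py', 'utils.py', 'models.py']
```

Note that, like the retriever, this expands only from the original hits (one hop), not transitively from the newly added files.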
diff --git a/code_chatbot/indexer.py b/code_chatbot/indexer.py
new file mode 100644
index 0000000000000000000000000000000000000000..585574bdc35d1e4199bf2fb810f3368d197f21c7
--- /dev/null
+++ b/code_chatbot/indexer.py
@@ -0,0 +1,237 @@
+import os
+from typing import List
+from langchain_core.documents import Document
+from langchain_community.vectorstores import Chroma
+from langchain_google_genai import GoogleGenerativeAIEmbeddings
+from code_chatbot.chunker import StructuralChunker
+import shutil
+import logging
+
+logger = logging.getLogger(__name__)
+
+# Global ChromaDB client cache to avoid "different settings" error
+_chroma_clients = {}
+
+def get_chroma_client(persist_directory: str):
+ """Get or create a shared ChromaDB client for a given path."""
+ global _chroma_clients
+
+ if persist_directory not in _chroma_clients:
+ import chromadb
+ from chromadb.config import Settings
+
+ _chroma_clients[persist_directory] = chromadb.PersistentClient(
+ path=persist_directory,
+ settings=Settings(
+ anonymized_telemetry=False,
+ allow_reset=True
+ )
+ )
+
+ return _chroma_clients[persist_directory]
+
+
+class Indexer:
+ """
+ Indexes code files into a vector database, using StructuralChunker
+ for structure-aware splitting.
+ """
+ def __init__(self, persist_directory: str = "chroma_db", embedding_function=None, provider: str = "gemini", api_key: str = None):
+ self.persist_directory = persist_directory
+ self.provider = provider
+
+ # Initialize Structural Chunker
+ self.chunker = StructuralChunker()
+
+ # Setup Embeddings (only Gemini supported)
+ if embedding_function:
+ self.embedding_function = embedding_function
+ else:
+ if provider == "gemini":
+ api_key = api_key or os.getenv("GOOGLE_API_KEY")
+ if not api_key:
+ raise ValueError("Google API Key is required for Gemini Embeddings")
+ self.embedding_function = GoogleGenerativeAIEmbeddings(
+ model="models/text-embedding-004",
+ google_api_key=api_key
+ )
+ else:
+ raise ValueError(f"Unsupported embedding provider: {provider}. Only 'gemini' is supported.")
+
+ def clear_collection(self, collection_name: str = "codebase"):
+ """
+ Safely clears a collection from the vector database.
+ """
+ try:
+ client = get_chroma_client(self.persist_directory)
+ try:
+ client.delete_collection(collection_name)
+ logger.info(f"Deleted collection '{collection_name}'")
+ except ValueError:
+ # Collection doesn't exist
+ pass
+ except Exception as e:
+ logger.warning(f"Failed to clear collection: {e}")
+
+
+ def index_documents(self, documents: List[Document], collection_name: str = "codebase", vector_db_type: str = "chroma"):
+ """
+ Splits documents structurally and generates embeddings.
+ Supports 'chroma' and 'faiss'.
+ """
+ if not documents:
+ logger.warning("No documents to index.")
+ return
+
+ all_chunks = []
+ for doc in documents:
+ # chunker.chunk returns List[Document]
+ file_chunks = self.chunker.chunk(doc.page_content, doc.metadata["file_path"])
+ all_chunks.extend(file_chunks)
+
+ if not all_chunks:
+ logger.warning("Chunker produced no chunks; nothing to index.")
+ return
+
+ # Filter out complex metadata and drop None values that slip through
+ from langchain_community.vectorstores.utils import filter_complex_metadata
+
+ for doc in all_chunks:
+ doc.metadata = {k: v for k, v in doc.metadata.items() if v is not None}
+
+ all_chunks = filter_complex_metadata(all_chunks)
+
+ if vector_db_type == "chroma":
+ # Use shared client to avoid "different settings" error
+ chroma_client = get_chroma_client(self.persist_directory)
+
+ vectordb = Chroma(
+ client=chroma_client,
+ embedding_function=self.embedding_function,
+ collection_name=collection_name
+ )
+ elif vector_db_type == "faiss":
+ from langchain_community.vectorstores import FAISS
+ # FAISS is in-memory by default; built below via from_documents and saved to disk
+ vectordb = None
+ elif vector_db_type == "qdrant":
+ vectordb = None # Built in bulk later
+ else:
+ raise ValueError(f"Unsupported Vector DB: {vector_db_type}")
+
+ # Batch processing
+ batch_size = 100
+ total_chunks = len(all_chunks)
+
+ logger.info(f"Indexing {total_chunks} chunks in batches of {batch_size}...")
+
+ from tqdm import tqdm
+ import time
+
+ # FAISS does not persist incrementally, so build the whole index in one
+ # from_documents call rather than looping over batches.
+ if vector_db_type == "faiss":
+ from langchain_community.vectorstores import FAISS
+ vectordb = FAISS.from_documents(all_chunks, self.embedding_function)
+ vectordb.save_local(folder_path=self.persist_directory, index_name=collection_name)
+ return vectordb
+
+ elif vector_db_type == "qdrant":
+ from langchain_qdrant import QdrantVectorStore
+ from qdrant_client import QdrantClient
+
+ url = os.getenv("QDRANT_URL")
+ api_key = os.getenv("QDRANT_API_KEY")
+
+ if not url:
+ # Fall back to an in-process Qdrant instance
+ logger.info("No QDRANT_URL found, using local in-memory Qdrant")
+ vectordb = QdrantVectorStore.from_documents(
+ documents=all_chunks,
+ embedding=self.embedding_function,
+ location=":memory:",
+ collection_name=collection_name,
+ )
+ else:
+ vectordb = QdrantVectorStore.from_documents(
+ documents=all_chunks,
+ embedding=self.embedding_function,
+ url=url,
+ api_key=api_key,
+ collection_name=collection_name,
+ prefer_grpc=True
+ )
+ return vectordb
+
+ # Loop for Chroma (existing logic)
+ for i in range(0, total_chunks, batch_size):
+ batch = all_chunks[i:i + batch_size]
+ try:
+ vectordb.add_documents(documents=batch)
+ logger.info(f"Indexed batch {i // batch_size + 1}/{(total_chunks + batch_size - 1) // batch_size}")
+ # Optional: slight delay to be nice to API
+ time.sleep(0.5)
+ except Exception as e:
+ logger.error(f"Error indexing batch {i}: {e}")
+ # Skip the failed batch and continue with the remaining ones
+ continue
+
+
+ # PersistentClient auto-persists
+ logger.info(f"Indexed {len(all_chunks)} chunks into collection '{collection_name}' at {self.persist_directory}")
+ return vectordb
+
+ def get_retriever(self, collection_name: str = "codebase", k: int = 10, vector_db_type: str = "chroma"):
+ """Get a retriever for the specified collection. Default k=10 for comprehensive results."""
+ logger.info(f"Creating retriever for collection '{collection_name}' from {self.persist_directory}")
+
+ if vector_db_type == "chroma":
+ # Use shared client to avoid "different settings" error
+ chroma_client = get_chroma_client(self.persist_directory)
+
+ # Load existing vector store
+ vector_store = Chroma(
+ client=chroma_client,
+ collection_name=collection_name,
+ embedding_function=self.embedding_function,
+ )
+
+ # Log collection info
+ try:
+ collection = vector_store._collection
+ count = collection.count()
+ logger.info(f"Collection '{collection_name}' has {count} documents")
+ except Exception as e:
+ logger.warning(f"Could not get collection count: {e}")
+
+ elif vector_db_type == "faiss":
+ from langchain_community.vectorstores import FAISS
+ try:
+ vector_store = FAISS.load_local(
+ folder_path=self.persist_directory,
+ embeddings=self.embedding_function,
+ index_name=collection_name,
+ allow_dangerous_deserialization=True # Codebase trust assumed for local use
+ )
+ logger.info(f"Loaded FAISS index from {self.persist_directory}")
+ except Exception as e:
+ logger.error(f"Failed to load FAISS index: {e}")
+ # Re-raise: the caller must index before retrieving
+ raise
+ elif vector_db_type == "qdrant":
+ from langchain_qdrant import QdrantVectorStore
+ from qdrant_client import QdrantClient
+
+ url = os.getenv("QDRANT_URL")
+ api_key = os.getenv("QDRANT_API_KEY")
+
+ # QdrantVectorStore requires an explicit client; it does not build one from url/api_key
+ vector_store = QdrantVectorStore(
+ client=QdrantClient(url=url, api_key=api_key),
+ collection_name=collection_name,
+ embedding=self.embedding_function,
+ )
+ logger.info(f"Connected to Qdrant at {url}")
+
+ else:
+ raise ValueError(f"Unsupported Vector DB: {vector_db_type}")
+
+ retriever = vector_store.as_retriever(search_kwargs={"k": k})
+ logger.info(f"Retriever created with k={k}")
+ return retriever
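The batch arithmetic in `index_documents` (fixed-size slicing plus the `(total + batch_size - 1) // batch_size` batch count) can be factored into small helpers and checked in isolation. A sketch, with illustrative names:

```python
def batches(items, batch_size=100):
    """Yield successive fixed-size slices, mirroring the Chroma indexing loop."""
    for i in range(0, len(items), batch_size):
        yield items[i:i + batch_size]

def total_batches(n_items, batch_size=100):
    """Ceiling division, as used for the batch-progress log."""
    return (n_items + batch_size - 1) // batch_size

chunks = list(range(250))
sizes = [len(b) for b in batches(chunks)]
print(sizes, total_batches(len(chunks)))
# → [100, 100, 50] 3
```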
diff --git a/code_chatbot/indexing_progress.py b/code_chatbot/indexing_progress.py
new file mode 100644
index 0000000000000000000000000000000000000000..bdb407428acab3f96654ddf3dd00fdb944b73429
--- /dev/null
+++ b/code_chatbot/indexing_progress.py
@@ -0,0 +1,255 @@
+"""
+Optimized indexing with progress tracking for Streamlit UI
+"""
+
+import os
+import time
+import shutil
+import logging
+from typing import List, Tuple
+from langchain_core.documents import Document
+import streamlit as st
+
+logger = logging.getLogger(__name__)
+
+def index_with_progress(
+ source_input: str,
+ source_type: str,
+ provider: str,
+ embedding_provider: str,
+ embedding_api_key: str,
+ vector_db_type: str,
+ use_agent: bool,
+ api_key: str,
+ gemini_model: str = None
+) -> Tuple[object, bool]:
+ """
+ Index a codebase with detailed progress tracking.
+ Returns (chat_engine, success)
+ """
+ from code_chatbot.universal_ingestor import process_source
+ from code_chatbot.ast_analysis import ASTGraphBuilder
+ from code_chatbot.indexer import Indexer
+ from code_chatbot.graph_rag import GraphEnhancedRetriever
+ from code_chatbot.rag import ChatEngine
+ from code_chatbot.chunker import StructuralChunker
+ from langchain_community.vectorstores import Chroma, FAISS
+ from langchain_community.vectorstores.utils import filter_complex_metadata
+
+ # Create progress tracking
+ progress_bar = st.progress(0)
+ status_text = st.empty()
+
+ try:
+ # Stage 1: Extract & Ingest (0-20%)
+ status_text.text("📦 Stage 1/4: Extracting and ingesting files...")
+ progress_bar.progress(0.05)
+
+ extract_to = os.path.join("data", "extracted")
+
+ if os.path.exists(extract_to):
+ status_text.text("🧹 Cleaning previous data...")
+ shutil.rmtree(extract_to)
+
+ progress_bar.progress(0.10)
+
+ documents, local_path = process_source(source_input, extract_to)
+ progress_bar.progress(0.20)
+ status_text.text(f"✅ Stage 1 Complete: Ingested {len(documents)} files")
+
+ # Stage 2: AST Analysis (20-40%)
+ status_text.text("🧠 Stage 2/4: Building AST Knowledge Graph...")
+ progress_bar.progress(0.25)
+
+ ast_builder = ASTGraphBuilder()
+ total_docs = len(documents)
+
+ for idx, doc in enumerate(documents):
+ if idx % 10 == 0:
+ progress = 0.25 + (0.15 * (idx / total_docs))
+ progress_bar.progress(progress)
+ status_text.text(f"🧠 Stage 2/4: Analyzing file {idx+1}/{total_docs}...")
+
+ ast_builder.add_file(doc.metadata['file_path'], doc.page_content)
+
+ os.makedirs(local_path, exist_ok=True)
+ graph_path = os.path.join(local_path, "ast_graph.graphml")
+ ast_builder.save_graph(graph_path)
+
+ progress_bar.progress(0.40)
+ status_text.text(f"✅ Stage 2 Complete: Graph with {ast_builder.graph.number_of_nodes()} nodes")
+
+ # Stage 3: Chunking (40-50%)
+ status_text.text("✂️ Stage 3/4: Chunking documents...")
+ progress_bar.progress(0.42)
+
+ indexer = Indexer(
+ provider=embedding_provider,
+ api_key=embedding_api_key
+ )
+
+ indexer.clear_collection(collection_name="codebase")
+ progress_bar.progress(0.45)
+
+ chunker = StructuralChunker()
+ all_chunks = []
+
+ for idx, doc in enumerate(documents):
+ if idx % 5 == 0:
+ progress = 0.45 + (0.05 * (idx / total_docs))
+ progress_bar.progress(progress)
+ status_text.text(f"✂️ Stage 3/4: Chunking file {idx+1}/{total_docs}...")
+
+ file_chunks = chunker.chunk(doc.page_content, doc.metadata["file_path"])
+ all_chunks.extend(file_chunks)
+
+ progress_bar.progress(0.50)
+ status_text.text(f"✅ Stage 3 Complete: {len(all_chunks)} chunks from {len(documents)} files")
+
+ # Stage 4: Generate Embeddings & Index (50-100%)
+ status_text.text(f"🔮 Stage 4/4: Generating embeddings for {len(all_chunks)} chunks...")
+ if len(all_chunks) > 500:
+ status_text.text("⚠️ Large codebase detected. This may take 2-5 minutes...")
+ progress_bar.progress(0.55)
+
+ # Clean metadata
+ for doc in all_chunks:
+ doc.metadata = {k:v for k,v in doc.metadata.items() if v is not None}
+ all_chunks = filter_complex_metadata(all_chunks)
+
+ # Index with progress
+ batch_size = 100
+ total_chunks = len(all_chunks)
+
+ if vector_db_type == "faiss":
+ status_text.text(f"🔮 Generating {total_chunks} embeddings (FAISS - one batch)...")
+ vectordb = FAISS.from_documents(all_chunks, indexer.embedding_function)
+ vectordb.save_local(folder_path=indexer.persist_directory, index_name="codebase")
+ progress_bar.progress(1.0)
+
+ elif vector_db_type == "qdrant":
+ from langchain_qdrant import QdrantVectorStore
+ status_text.text(f"🔮 Generating {total_chunks} embeddings (Qdrant)...")
+
+ url = os.getenv("QDRANT_URL")
+ api_key_qdrant = os.getenv("QDRANT_API_KEY")
+
+ if url:
+ vectordb = QdrantVectorStore.from_documents(
+ documents=all_chunks,
+ embedding=indexer.embedding_function,
+ url=url,
+ api_key=api_key_qdrant,
+ collection_name="codebase",
+ prefer_grpc=True
+ )
+ else:
+ # Fall back to an in-process Qdrant instance
+ vectordb = QdrantVectorStore.from_documents(
+ documents=all_chunks,
+ embedding=indexer.embedding_function,
+ location=":memory:",
+ collection_name="codebase",
+ )
+ progress_bar.progress(1.0)
+
+ else: # Chroma
+ from code_chatbot.indexer import get_chroma_client
+ chroma_client = get_chroma_client(indexer.persist_directory)
+
+ vectordb = Chroma(
+ client=chroma_client,
+ embedding_function=indexer.embedding_function,
+ collection_name="codebase"
+ )
+
+ for i in range(0, total_chunks, batch_size):
+ batch = all_chunks[i:i + batch_size]
+ batch_num = i // batch_size + 1
+ total_batches = (total_chunks + batch_size - 1) // batch_size
+
+ progress = 0.55 + (0.45 * (i / total_chunks))
+ progress_bar.progress(progress)
+ status_text.text(f"🔮 Batch {batch_num}/{total_batches} ({min(i + batch_size, total_chunks)}/{total_chunks} chunks)")
+
+ # Retry logic for rate limits
+ max_retries = 3
+ retry_count = 0
+ success = False
+
+ while retry_count < max_retries and not success:
+ try:
+ vectordb.add_documents(documents=batch)
+ time.sleep(0.2) # Rate limit protection
+ success = True
+ except Exception as e:
+ error_msg = str(e).lower()
+
+ # Check if it's a rate limit error
+ if "rate" in error_msg or "quota" in error_msg or "429" in error_msg or "resource_exhausted" in error_msg:
+ retry_count += 1
+ if retry_count < max_retries:
+ wait_time = 30 * retry_count # 30s, then 60s
+ status_text.text(f"⚠️ Rate limit hit. Waiting {wait_time}s before retry {retry_count}/{max_retries}...")
+ st.warning(f"⏰ Embedding API rate limit. Pausing {wait_time}s... (Retry {retry_count}/{max_retries})")
+
+ # Show countdown
+ for remaining in range(wait_time, 0, -5):
+ status_text.text(f"⏰ Waiting {remaining}s for rate limit to reset...")
+ time.sleep(5)
+
+ status_text.text(f"🔄 Retrying batch {batch_num}/{total_batches}...")
+ else:
+ st.error(f"❌ Failed after {max_retries} retries. Wait 5-10 minutes and try again.")
+ raise Exception(f"Rate limit exceeded after {max_retries} retries. Please wait and try again.")
+ else:
+ # Not a rate limit error, just warn and continue
+ st.warning(f"⚠️ Batch {batch_num} error: {str(e)[:50]}...")
+ break # Skip this batch and continue
+
+ # PersistentClient auto-persists, no need to call vectordb.persist()
+ progress_bar.progress(1.0)
+
+ status_text.text(f"✅ Stage 4 Complete: Indexed {len(all_chunks)} chunks!")
+
+ # Final step: initialize the chat engine
+ status_text.text("🚀 Initializing chat engine...")
+
+ base_retriever = indexer.get_retriever(vector_db_type=vector_db_type)
+
+ graph_retriever = GraphEnhancedRetriever(
+ base_retriever=base_retriever,
+ repo_dir=local_path
+ )
+
+ repo_files = list(set([doc.metadata['file_path'] for doc in documents]))
+
+ # Use selected model or fallback to defaults
+ model_name = None
+ if provider == "gemini":
+ model_name = gemini_model if gemini_model else "gemini-2.0-flash-exp"
+ elif provider == "groq":
+ model_name = "llama-3.3-70b-versatile"
+
+ chat_engine = ChatEngine(
+ retriever=graph_retriever,
+ provider=provider,
+ model_name=model_name,
+ api_key=api_key,
+ repo_files=repo_files,
+ repo_name=os.path.basename(source_input) if source_input else "Codebase",
+ use_agent=use_agent,
+ repo_dir=local_path
+ )
+
+ # Final success
+ st.success(f"""
+ 🎉 **Indexing Complete!**
+ - Files: {len(documents)}
+ - Chunks: {len(all_chunks)}
+ - Graph Nodes: {ast_builder.graph.number_of_nodes()}
+ - Ready to chat!
+ """)
+
+ progress_bar.empty()
+ status_text.empty()
+
+ return chat_engine, True
+
+ except Exception as e:
+ st.error(f"❌ Error during indexing: {e}")
+ logger.error(f"Indexing failed: {e}", exc_info=True)
+ progress_bar.empty()
+ status_text.empty()
+ return None, False
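The retry loop above keys off substrings of the error message and waits `30 * retry_count` seconds before each retry. The classification and the backoff schedule can be sketched as pure functions (illustrative names, not part of the module):

```python
def is_rate_limit_error(message: str) -> bool:
    """Match the substrings the retry loop treats as rate limiting."""
    msg = message.lower()
    return any(tok in msg for tok in ("rate", "quota", "429", "resource_exhausted"))

def backoff_schedule(max_retries: int = 3, base: int = 30):
    """Wait times before each retry: base before retry 1, 2*base before retry 2, ..."""
    return [base * attempt for attempt in range(1, max_retries)]

print(is_rate_limit_error("HTTP 429: Too Many Requests"))  # → True
print(backoff_schedule())  # → [30, 60]
```

With `max_retries=3`, only two waits ever happen (before retries 1 and 2); the third failure raises instead of waiting.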
diff --git a/code_chatbot/ingestor.py b/code_chatbot/ingestor.py
new file mode 100644
index 0000000000000000000000000000000000000000..268850baa59c6c4c3b74c5a4dc29b903ff32bf48
--- /dev/null
+++ b/code_chatbot/ingestor.py
@@ -0,0 +1,103 @@
+import os
+import zipfile
+import tempfile
+import shutil
+from typing import List, Optional
+from langchain_core.documents import Document
+import logging
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# Extensions to ignore (binaries, images, etc.)
+IGNORE_EXTENSIONS = {
+ '.pyc', '.git', '.github', '.idea', '.vscode', '.DS_Store',
+ '.png', '.jpg', '.jpeg', '.gif', '.ico', '.svg',
+ '.mp4', '.mov', '.mp3', '.wav',
+ '.zip', '.tar', '.gz', '.pkl', '.bin', '.exe', '.dll', '.so', '.dylib',
+ '.pdf', '.docx', '.xlsx', '.pptx'
+}
+
+# Directories to ignore
+IGNORE_DIRS = {
+ '__pycache__', '.git', '.github', '.idea', '.vscode', 'node_modules', 'venv', '.venv', 'env', '.env', 'dist', 'build', 'target'
+}
+
+def is_text_file(file_path: str) -> bool:
+ """Check if a file is likely a text file based on extension and content."""
+ _, ext = os.path.splitext(file_path)
+ if ext.lower() in IGNORE_EXTENSIONS:
+ return False
+
+ try:
+ with open(file_path, 'r', encoding='utf-8') as f:
+ f.read(1024)
+ return True
+ except UnicodeDecodeError:
+ return False
+ except Exception:
+ return False
+
+def process_zip(zip_path: str, extract_to: str) -> List[Document]:
+ """
+ Extracts a ZIP file and returns a list of LangChain Documents.
+
+ Args:
+ zip_path: Path to the uploaded ZIP file.
+ extract_to: Directory to extract files to.
+
+ Returns:
+ List[Document]: List of documents with content and metadata.
+ """
+ documents = []
+
+ if not os.path.exists(extract_to):
+ os.makedirs(extract_to)
+
+ try:
+ with zipfile.ZipFile(zip_path, 'r') as zip_ref:
+ zip_ref.extractall(extract_to)
+
+ logger.info(f"Extracted {zip_path} to {extract_to}")
+
+ # Walk through the extracted files
+ for root, dirs, files in os.walk(extract_to):
+ # Modify dirs in-place to skip ignored directories
+ dirs[:] = [d for d in dirs if d not in IGNORE_DIRS and not d.startswith('.')]
+
+ for file in files:
+ if file.startswith('.'):
+ continue
+
+ file_path = os.path.join(root, file)
+
+ if is_text_file(file_path):
+ try:
+ with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
+ content = f.read()
+
+ # Create relative path for metadata
+ rel_path = os.path.relpath(file_path, extract_to)
+
+ doc = Document(
+ page_content=content,
+ metadata={
+ "source": rel_path,
+ "file_path": file_path,
+ "file_name": file
+ }
+ )
+ documents.append(doc)
+ except Exception as e:
+ logger.warning(f"Failed to read {file_path}: {e}")
+
+ logger.info(f"Processed {len(documents)} documents from {zip_path}")
+ return documents
+
+ except zipfile.BadZipFile:
+ logger.error(f"Invalid ZIP file: {zip_path}")
+ raise ValueError("The provided file is not a valid ZIP archive.")
+ except Exception as e:
+ logger.error(f"Error processing ZIP: {e}")
+ raise
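The text-vs-binary probe and the directory pruning in `process_zip` can be tested without LangChain. A self-contained stdlib sketch (the function name is illustrative):

```python
import os

IGNORE_DIRS = {"__pycache__", ".git", "node_modules", "venv"}

def collect_text_files(extract_to):
    """Walk an extracted tree and return relative paths of readable text
    files, mirroring process_zip's pruning of dot/ignored directories and
    its 1 KB UTF-8 probe for binary detection."""
    collected = []
    for root, dirs, files in os.walk(extract_to):
        # Prune ignored and hidden directories in-place so os.walk skips them
        dirs[:] = [d for d in dirs if d not in IGNORE_DIRS and not d.startswith(".")]
        for name in files:
            if name.startswith("."):
                continue
            path = os.path.join(root, name)
            try:
                with open(path, encoding="utf-8") as f:
                    f.read(1024)  # UnicodeDecodeError here means binary
            except (UnicodeDecodeError, OSError):
                continue
            collected.append(os.path.relpath(path, extract_to))
    return collected
```

Note that `os.walk` only honours the pruning because `dirs` is mutated in place (`dirs[:] = ...`); rebinding `dirs` would silently disable it.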
diff --git a/code_chatbot/llm_retriever.py b/code_chatbot/llm_retriever.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f273192820b93c85dc5c7205a5553a822694139
--- /dev/null
+++ b/code_chatbot/llm_retriever.py
@@ -0,0 +1,166 @@
+import logging
+import os
+from typing import Any, Dict, List, Optional, Set
+from anytree import Node, RenderTree
+from langchain_core.callbacks import CallbackManagerForRetrieverRun
+from langchain_core.documents import Document
+from langchain_core.retrievers import BaseRetriever
+from langchain_core.language_models import BaseChatModel
+from langchain_core.messages import SystemMessage, HumanMessage
+from pydantic import PrivateAttr
+import Levenshtein
+
+logger = logging.getLogger(__name__)
+
+class LLMRetriever(BaseRetriever):
+ """
+ Retriever that uses an LLM to select relevant files from the project structure.
+ Adapted from generic Sage implementation to work with LangChain models.
+ """
+
+ llm: BaseChatModel
+ repo_files: List[str]
+ top_k: int = 5
+ repo_structure: str = ""
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ # repo_structure is a declared Pydantic field, so plain assignment is safe here
+ self.repo_structure = self._build_repo_structure(self.repo_files)
+
+ def _build_repo_structure(self, files: List[str]) -> str:
+ """Builds a visual tree structure of the repository."""
+ # Build tree
+ root = Node("root")
+ nodes = {"": root}
+
+ for file_path in files:
+ parts = file_path.strip("/").split("/")
+ current_path = ""
+ parent = root
+
+ for part in parts:
+ current_path = f"{current_path}/{part}" if current_path else part
+ if current_path not in nodes:
+ nodes[current_path] = Node(part, parent=parent)
+ parent = nodes[current_path]
+
+ # Render tree
+ render = ""
+ for pre, _, node in RenderTree(root):
+ if node.name == "root": continue
+ # Simplify characters for token efficiency
+ line = f"{pre}{node.name}"
+ line = line.replace("└", " ").replace("├", " ").replace("│", " ").replace("─", " ")
+ render += line + "\n"
+
+ return render
+
+ def _get_relevant_documents(self, query: str, *, run_manager: CallbackManagerForRetrieverRun) -> List[Document]:
+ """Retrieve relevant documents for a given query."""
+ try:
+ logger.info("LLMRetriever: Asking LLM to select files...")
+ filenames = self._ask_llm_to_retrieve(query)
+ logger.info(f"LLMRetriever: Selected {len(filenames)} files: {filenames}")
+
+ documents = []
+ for filename in filenames:
+ # repo_files are expected to be readable local paths (the app passes
+ # paths produced during ingestion); fall back to an empty stub otherwise.
+ try:
+ if os.path.exists(filename):
+ with open(filename, "r", errors='ignore') as f:
+ content = f.read()
+ documents.append(Document(
+ page_content=content,
+ metadata={"file_path": filename, "source": "llm_retriever"}
+ ))
+ else:
+ documents.append(Document(
+ page_content="",
+ metadata={"file_path": filename, "source": "llm_retriever", "error": "File not found"}
+ ))
+ except Exception as e:
+ logger.warning(f"Failed to read file {filename}: {e}")
+
+ return documents
+ except Exception as e:
+ logger.error(f"LLMRetriever failed: {e}")
+ return []
+
+ def _ask_llm_to_retrieve(self, user_query: str) -> List[str]:
+ """Feeds the file hierarchy and user query to the LLM."""
+
+ system_prompt = f"""
+You are a senior software engineer helping to navigate a codebase.
+Your task is to identify the top {self.top_k} files in the repository that are most likely to contain the answer to the user's query.
+
+Here is the file structure of the repository:
+{self.repo_structure}
+
+Rules:
+1. Respond ONLY with a list of file paths, one per line.
+2. Do not include any explanation or conversational text.
+3. Select files that are relevant to: "{user_query}"
+4. If the file paths in the structure are relative, return them as they appear in the structure.
+"""
+ messages = [
+ SystemMessage(content=system_prompt),
+ HumanMessage(content=f"User Query: {user_query}")
+ ]
+
+ response = self.llm.invoke(messages)
+ text = response.content.strip()
+ logger.debug(f"Raw LLM response: {text}")
+
+ # Parse response
+ lines = text.split('\n')
+ selected_files = []
+ for line in lines:
+ cleaned = line.strip().strip("- ").strip("* ")
+ if cleaned:
+ # Validate if it exists in our known files (fuzzy match if needed)
+ match = self._find_best_match(cleaned)
+ if match:
+ selected_files.append(match)
+
+ # dict.fromkeys dedupes while preserving the LLM's ranking order
+ return list(dict.fromkeys(selected_files))[:self.top_k]
+
+ def _find_best_match(self, filename: str) -> Optional[str]:
+ """Finds the closest matching filename from the repo."""
+ if filename in self.repo_files:
+ return filename
+
+ # 1. Try exact match on basename
+ for f in self.repo_files:
+ if os.path.basename(f) == filename:
+ return f
+
+ # 2. Fuzzy match
+ best_match = None
+ min_dist = float('inf')
+
+ for f in self.repo_files:
+ # Compare against full paths, since the LLM sees full paths in the rendered structure
+ dist = Levenshtein.distance(filename, f)
+ if dist < min_dist:
+ min_dist = dist
+ best_match = f
+
+ if min_dist < 20: # Arbitrary threshold
+ return best_match
+
+ return None
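`_find_best_match` depends on the third-party python-Levenshtein package; the same exact/basename/fuzzy cascade can be sketched with stdlib `difflib` (the 0.6 similarity cutoff here is an illustrative stand-in for the distance threshold above):

```python
import difflib
import os

def find_best_match(filename, repo_files, cutoff=0.6):
    """Resolve an LLM-suggested filename: exact path, then exact basename,
    then fuzzy match over full paths (sketch of the cascade above)."""
    if filename in repo_files:
        return filename
    # Exact basename match first
    for f in repo_files:
        if os.path.basename(f) == filename:
            return f
    # Fuzzy match on full paths; returns [] when nothing clears the cutoff
    matches = difflib.get_close_matches(filename, repo_files, n=1, cutoff=cutoff)
    return matches[0] if matches else None

repo = ["src/auth/login.py", "src/db/models.py"]
print(find_best_match("login.py", repo))         # → src/auth/login.py
print(find_best_match("src/db/model.py", repo))  # → src/db/models.py
```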
diff --git a/code_chatbot/prompts.py b/code_chatbot/prompts.py
new file mode 100644
index 0000000000000000000000000000000000000000..71644eabfe2cb12a259e1acdc3bca3bcc3dc9767
--- /dev/null
+++ b/code_chatbot/prompts.py
@@ -0,0 +1,341 @@
+# prompts.py - Enhanced Prompts for Code Chatbot
+
+SYSTEM_PROMPT_AGENT = """You are an expert software engineering assistant with deep expertise in code analysis, architecture, and feature development for the codebase: {repo_name}.
+
+Your mission is to help developers understand, navigate, and enhance their codebase through intelligent analysis and contextual responses.
+
+**CORE CAPABILITIES:**
+
+1. **Code Understanding & Explanation**:
+ - Analyze code structure, patterns, and architectural decisions
+ - Explain complex logic in clear, digestible terms
+ - Trace execution flows and data transformations
+ - Identify dependencies and component relationships
+
+2. **Strategic Tool Usage**:
+ Available tools and when to use them:
+ - `search_codebase(query)`: Find relevant code by semantic meaning or keywords
+ * Use multiple searches with different queries for complex questions
+ * Search for: function names, class names, patterns, concepts
+ - `read_file(file_path)`: Get complete file contents for detailed analysis
+ * Use when you need full context (imports, class structure, etc.)
+ - `list_files(directory)`: Understand project organization
+ * Use to map out module structure or find related files
+ - `find_callers(function_name)`: Find all functions that CALL a specific function
+ * Use for: "What uses this function?", "Where is this called from?"
+ * Great for impact analysis and understanding dependencies
+ - `find_callees(function_name)`: Find all functions a specific function CALLS
+ * Use for: "What does this function do?", "What are its dependencies?"
+ * Great for understanding implementation details
+ - `find_call_chain(start_func, end_func)`: Find the call path between two functions
+ * Use for: "How does execution flow from main() to save_data()?"
+ * Great for tracing complex workflows
+
+3. **Answer Structure** (adapt based on question complexity):
+
+ For "How does X work?" questions:
+````markdown
+ ## Overview
+ [2-3 sentence high-level explanation]
+
+ ## Implementation Details
+ [Step-by-step breakdown with code references]
+
+ ## Key Components
+ - **File**: `path/to/file.py`
+ - **Function/Class**: `name` (lines X-Y)
+ - **Purpose**: [what it does]
+
+ ## Code Example
+```language
+ [Actual code from the codebase with inline comments]
+```
+
+ ## Flow Diagram (if complex)
+ [Text-based flow or numbered steps]
+
+ ## Related Components
+ [Files/modules that interact with this feature]
+````
+
+ For "Where is X?" questions:
+````markdown
+ ## Location
+ **File**: `path/to/file.py` (lines X-Y)
+
+ ## Code Snippet
+```language
+ [Relevant code]
+```
+
+ ## Context
+ [Brief explanation of how it fits in the architecture]
+````
+
+ For "Add/Implement X" requests:
+````markdown
+ ## Proposed Implementation
+ [High-level approach aligned with existing patterns]
+
+ ## Code Changes
+
+ ### 1. Create/Modify: `path/to/file.py`
+```language
+ [New or modified code following project conventions]
+```
+
+ ### 2. [Additional files if needed]
+
+ ## Integration Points
+ - [Where this connects to existing code]
+ - [Any dependencies or imports needed]
+
+ ## Considerations
+ - [Edge cases, security, performance notes]
+````
+
+4. **Quality Standards**:
+ - ✅ Always cite specific files with paths (e.g., `src/auth/login.py:45-67`)
+ - ✅ Use actual code from the codebase, never generic placeholders
+ - ✅ Explain the "why" - architectural reasoning, design patterns used
+ - ✅ Maintain consistency with existing code style and patterns
+ - ✅ Highlight potential issues, edge cases, or important constraints
+ - ✅ When suggesting code, follow the project's naming conventions and structure
+ - ❌ Don't make assumptions - use tools to verify information
+ - ❌ Don't provide incomplete answers - use multiple tool calls if needed
+
+5. **Response Principles**:
+ - **Grounded**: Every statement should reference actual code
+ - **Complete**: Answer should eliminate need for follow-up questions
+ - **Practical**: Include actionable information and concrete examples
+ - **Contextual**: Explain how components fit into broader architecture
+ - **Honest**: If information is missing or unclear, explicitly state it
+
+**WORKFLOW**:
+1. Analyze the question to identify what information is needed
+2. Use tools strategically to gather comprehensive context
+3. Synthesize information into a structured, clear answer
+4. Validate that all claims are backed by actual code references
+
+**SPECIAL INSTRUCTIONS FOR FEATURE REQUESTS**:
+When users ask to "add", "implement", or "create" features:
+1. First, search for similar existing implementations in the codebase
+2. Identify the architectural patterns and conventions used
+3. Propose code that aligns with existing style and structure
+4. Show exact file modifications with before/after if modifying existing code
+5. List any new dependencies or configuration changes needed
+
+**CRITICAL OUTPUT RULES:**
+1. **NO HTML**: Do NOT generate HTML tags (e.g. `<div>`, `<span>`). Use ONLY standard Markdown.
+2. **NO UI MIMICRY**: Do NOT attempt to recreate UI elements like "source chips", buttons, or widgets.
+3. **NO HALLUCINATION**: Only cite files that actually exist in the retrieved context.
+
+Remember: You're not just answering questions - you're helping developers deeply understand and confidently modify their codebase.
+"""
+
+SYSTEM_PROMPT_LINEAR_RAG = """You are an expert software engineering assistant analyzing the codebase: {repo_name}.
+
+You have been provided with relevant code snippets retrieved from the codebase. Your task is to deliver a comprehensive, accurate answer that demonstrates deep understanding.
+
+**YOUR APPROACH:**
+
+1. **Analyze the Retrieved Context**:
+ - Review all provided code snippets carefully
+ - Identify the most relevant pieces for the question
+ - Note relationships between different code sections
+ - Recognize patterns, conventions, and architectural decisions
+
+2. **Construct Your Answer**:
+
+ **Structure Guidelines**:
+ - Start with a clear, direct answer to the question
+ - Organize with markdown headers (##) for major sections
+ - Use code blocks with language tags: ```python, ```javascript, etc.
+ - Reference specific files with paths and line numbers
+ - Use bullet points for lists of components or steps
+
+ **Content Requirements**:
+ - Quote relevant code snippets from the provided context
+ - Explain what the code does AND why it's designed that way
+ - Describe how different components interact
+ - Highlight important patterns, conventions, or architectural decisions
+ - Mention edge cases, error handling, or special considerations
+ - Connect the answer to broader system architecture when relevant
+
+3. **Code Presentation**:
+ - Always introduce code snippets with context (e.g., "In `src/auth.py`, the login handler:")
+ - Add inline comments to complex code for clarity
+ - Show imports and dependencies when relevant
+ - Indicate if code is simplified or truncated
+
+4. **Completeness Checklist**:
+ - [ ] Direct answer to the user's question
+ - [ ] Supporting code from the actual codebase
+ - [ ] Explanation of implementation approach
+ - [ ] File paths and locations cited
+ - [ ] Architectural context provided
+ - [ ] Related components mentioned
+
+**RETRIEVED CODE CONTEXT:**
+
+{context}
+
+---
+
+**ANSWER GUIDELINES:**
+- Be thorough but not verbose - every sentence should add value
+- Use technical precision - this is for experienced developers
+- Maintain consistency with the codebase's terminology and concepts
+- If the context doesn't fully answer the question, explicitly state what's missing
+- Prioritize accuracy over speculation - only discuss what you can verify from the code
+
+**OUTPUT FORMAT:**
+Provide your answer in well-structured markdown that a developer can immediately understand and act upon.
+
+**CRITICAL RULES:**
+- **NO HTML**: Do NOT generate HTML tags. Use ONLY standard Markdown.
+- **NO UI MIMICRY**: Do NOT try to create "source chips" or other UI elements.
+"""
+
+QUERY_EXPANSION_PROMPT = """Given a user question about a codebase, generate 3-5 diverse search queries optimized for semantic code search.
+
+**User Question:** {question}
+
+**Generate queries that cover:**
+1. **Direct Implementation**: Specific function/class names, file patterns
+2. **Conceptual/Semantic**: High-level concepts, feature names, problem domains
+3. **Related Systems**: Connected components, dependencies, integrations
+4. **Configuration/Setup**: Environment setup, constants, configuration files
+5. **Usage Examples**: Test files, example usage, API endpoints (if applicable)
+
+**Query Strategy:**
+- Mix specific technical terms with natural language
+- Include variations of terminology (e.g., "authentication", "auth", "login")
+- Consider both questions ("how does X work") and keywords ("X implementation")
+- Target different levels of abstraction (high-level concepts → specific details)
+
+**Output Format** (one query per line, no numbering):
+[query 1]
+[query 2]
+[query 3]
+[query 4]
+[query 5]
+
+Generate 3-5 queries based on question complexity:
+"""
+
+ANSWER_SYNTHESIS_PROMPT = """You are synthesizing information from multiple code search results to provide a comprehensive answer.
+
+**User Question:** {question}
+
+**Retrieved Information from Codebase:**
+{retrieved_context}
+
+**Your Task:**
+Create a unified, well-structured answer that:
+
+1. **Integrates All Sources**:
+ - Combine overlapping information intelligently
+ - Resolve any apparent contradictions
+ - Build a complete picture from fragments
+
+2. **Maintains Traceability**:
+ - Cite which files each piece of information comes from
+ - Format: "In `path/to/file.py:line-range`, ..."
+ - Include code snippets from the retrieved context
+
+3. **Adds Value**:
+ - Explain relationships between components
+ - Highlight architectural patterns
+ - Provide context on why things are implemented this way
+ - Note dependencies and integration points
+
+4. **Structured Presentation**:
+````markdown
+ ## Direct Answer
+ [Concise 2-3 sentence response to the question]
+
+ ## Detailed Explanation
+ [Comprehensive breakdown with code references]
+
+ ## Key Code Components
+ [List important files, functions, classes with their roles]
+
+ ## Code Examples
+ [Relevant snippets from retrieved context with explanations]
+
+ ## Additional Context
+ [Architecture notes, related features, considerations]
+````
+
+5. **Handle Gaps**:
+ - If information is incomplete, clearly state what's provided vs. what's missing
+ - Distinguish between definite facts from code vs. reasonable inferences
+ - Don't fabricate details not present in the retrieved context
+
+**Quality Criteria:**
+- Every claim backed by retrieved code
+- Clear file and location citations
+- Practical, actionable information
+- Appropriate technical depth for the question
+- Well-organized with markdown formatting
+
+Provide your synthesized answer:
+"""
+
+# Additional utility prompts for specific scenarios
+
+CODE_MODIFICATION_PROMPT = """You are suggesting code modifications for the codebase: {repo_name}.
+
+**User Request:** {user_request}
+
+**Existing Code Context:**
+{existing_code}
+
+**Your Task:**
+Provide a concrete implementation that:
+1. Follows existing code style and patterns from the codebase
+2. Integrates seamlessly with current architecture
+3. Handles edge cases and errors appropriately
+4. Includes necessary imports and dependencies
+
+**Output Format:**
+## Implementation Approach
+[Brief explanation of your solution and why it fits the codebase]
+
+## Code Changes
+
+### File: `path/to/file.py`
+````python
+# Add these imports at the top
+[new imports if needed]
+
+# Add/modify this code at line X or in function Y
+[your implementation with comments]
+````
+
+### [Additional files if needed]
+
+## Integration Notes
+- [How this connects to existing code]
+- [Any configuration or dependency updates needed]
+- [Testing considerations]
+
+## Edge Cases Handled
+- [List important edge cases your code addresses]
+"""
+
+ARCHITECTURE_EXPLANATION_PROMPT = """Explain the architecture and design patterns used in {repo_name} for: {topic}
+
+**Code Context:**
+{context}
+
+**Provide:**
+1. **High-Level Architecture**: Overall structure and component organization
+2. **Design Patterns**: Specific patterns used (MVC, Repository, Factory, etc.)
+3. **Data Flow**: How information moves through the system
+4. **Key Decisions**: Why this architecture was chosen
+5. **Diagram** (text-based): Visual representation of component relationships
+
+Format with clear sections and reference specific files.
+"""
\ No newline at end of file
diff --git a/code_chatbot/rag.py b/code_chatbot/rag.py
new file mode 100644
index 0000000000000000000000000000000000000000..df0866698a3d6c6ee9c4bbe5961bbada3011278c
--- /dev/null
+++ b/code_chatbot/rag.py
@@ -0,0 +1,304 @@
+from typing import List, Tuple, Any, Optional
+import logging
+from langchain_google_genai import ChatGoogleGenerativeAI
+from langchain_groq import ChatGroq
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder, PromptTemplate
+from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
+from langchain_core.retrievers import BaseRetriever
+# Simplified implementation that works with current langchain version
+# We'll implement history-aware retrieval manually
+from code_chatbot.reranker import Reranker
+from code_chatbot.retriever_wrapper import build_enhanced_retriever
+import os
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+class ChatEngine:
+ def __init__(
+ self,
+ retriever: BaseRetriever,
+ model_name: str = "gpt-4o",
+ provider: str = "openai",
+ api_key: str = None,
+ repo_name: Optional[str] = None,
+ use_agent: bool = True,
+ use_multi_query: bool = False,
+ use_reranking: bool = True,
+ repo_files: Optional[List[str]] = None,
+ repo_dir: str = ".", # New Argument
+ ):
+ self.base_retriever = retriever
+ self.model_name = model_name
+ self.provider = provider
+ self.api_key = api_key
+ self.repo_name = repo_name or "codebase"
+ self.use_agent = use_agent
+ self.use_multi_query = use_multi_query
+ self.use_reranking = use_reranking
+ self.repo_files = repo_files
+ self.repo_dir = repo_dir
+
+ # Initialize LLM
+ self.llm = self._get_llm()
+
+ # Initialize conversation history
+ self.chat_history = []
+
+ # Build enhanced vector retriever
+ self.vector_retriever = build_enhanced_retriever(
+ base_retriever=retriever,
+ llm=self.llm if use_multi_query else None, # Only for query expansion
+ use_multi_query=use_multi_query,
+ use_reranking=use_reranking,
+ )
+
+ # Initialize LLM Retriever if files are available
+ self.llm_retriever = None
+ if self.repo_files:
+ try:
+ from code_chatbot.llm_retriever import LLMRetriever
+ from langchain.retrievers import EnsembleRetriever
+
+ logger.info(f"Initializing LLMRetriever with {len(self.repo_files)} files.")
+ self.llm_retriever = LLMRetriever(
+ llm=self.llm,
+ repo_files=self.repo_files,
+ top_k=3
+ )
+
+ # Combine retrievers
+ self.retriever = EnsembleRetriever(
+ retrievers=[self.vector_retriever, self.llm_retriever],
+ weights=[0.6, 0.4]
+ )
+ except ImportError as e:
+ logger.warning(f"Could not load EnsembleRetriever or LLMRetriever: {e}")
+ self.retriever = self.vector_retriever
+ else:
+ self.retriever = self.vector_retriever
+
+ # Initialize Agent Graph if enabled
+ self.agent_executor = None
+ self.code_analyzer = None
+ if self.use_agent:
+ try:
+ from code_chatbot.agent_workflow import create_agent_graph
+ from code_chatbot.ast_analysis import EnhancedCodeAnalyzer
+ import os
+
+ logger.info(f"Building Agentic Workflow Graph for {self.repo_dir}...")
+
+ # Try to load code analyzer from saved graph
+ graph_path = os.path.join(self.repo_dir, "ast_graph.graphml") if self.repo_dir else None
+ if graph_path and os.path.exists(graph_path):
+ try:
+ import networkx as nx
+ self.code_analyzer = EnhancedCodeAnalyzer()
+ self.code_analyzer.graph = nx.read_graphml(graph_path)
+ logger.info(f"Loaded code analyzer with {self.code_analyzer.graph.number_of_nodes()} nodes")
+ except Exception as e:
+ logger.warning(f"Failed to load code analyzer: {e}")
+
+ self.agent_executor = create_agent_graph(
+ self.llm, self.retriever, self.repo_name,
+ self.repo_dir, self.provider, self.code_analyzer
+ )
+ except Exception as e:
+ logger.error(f"Failed to build Agent Graph: {e}")
+ self.use_agent = False
+
+ def _get_llm(self):
+ """Initialize the LLM based on provider (only Groq and Gemini supported)."""
+ api_key = self.api_key or os.getenv(f"{self.provider.upper()}_API_KEY")
+
+ if self.provider == "gemini":
+ if not api_key:
+ if not os.getenv("GOOGLE_API_KEY"):
+ raise ValueError("Google API Key is required for Gemini")
+
+ return ChatGoogleGenerativeAI(
+ model=self.model_name or "gemini-2.5-flash",
+ google_api_key=api_key,
+ temperature=0.2, # Low temp for agents
+ convert_system_message_to_human=True
+ )
+ elif self.provider == "groq":
+ if not api_key:
+ if not os.getenv("GROQ_API_KEY"):
+ raise ValueError("Groq API Key is required")
+
+ return ChatGroq(
+ model=self.model_name or "llama-3.3-70b-versatile",
+ groq_api_key=api_key,
+ temperature=0.2
+ )
+ else:
+ raise ValueError(f"Provider {self.provider} not supported. Only 'groq' and 'gemini' are supported.")
+
+
+ def _build_rag_chain(self):
+ """Builds a simplified RAG chain with history-aware retrieval."""
+ # For compatibility, we'll use a simpler approach that works with current langchain
+ # The history-aware retriever will be implemented in the chat method
+ return None # We'll handle retrieval manually in chat()
+
+ def _contextualize_query(self, question: str, history: List) -> str:
+ """Contextualize query based on chat history."""
+ if not history:
+ return question
+
+ # Build context from history
+ history_text = ""
+ for i in range(0, len(history), 2):
+ if i < len(history) and isinstance(history[i], HumanMessage):
+ history_text += f"User: {history[i].content}\n"
+ if i + 1 < len(history) and isinstance(history[i + 1], AIMessage):
+ history_text += f"Assistant: {history[i + 1].content}\n"
+
+ # Simple contextualization - just use the question for now
+ # In a full implementation, you'd use an LLM to rewrite the query
+ return question # Simplified
+
+ def chat(self, question: str) -> Tuple[str, List[dict]]:
+ """
+ Ask a question to the chatbot.
+ Uses Agentic Workflow if enabled, otherwise falls back to Linear RAG.
+ """
+ try:
+ # 1. Agentic Mode
+ if self.use_agent and self.agent_executor:
+ logger.info("Executing Agentic Workflow...")
+
+ # Contextualize with history
+ # Use comprehensive system prompt for high-quality answers
+ from code_chatbot.prompts import SYSTEM_PROMPT_AGENT
+ sys_content = SYSTEM_PROMPT_AGENT.format(repo_name=self.repo_name)
+ system_msg = SystemMessage(content=sys_content)
+
+ # Token Optimization: Only pass last 4 messages (2 turns) to keep context light.
+ recent_history = self.chat_history[-4:] if self.chat_history else []
+
+ inputs = {
+ "messages": [system_msg] + recent_history + [HumanMessage(content=question)]
+ }
+
+ # Run the graph
+ try:
+ final_state = self.agent_executor.invoke(inputs, config={"recursion_limit": 20})
+
+ # Extract Answer
+ messages = final_state["messages"]
+ raw_content = messages[-1].content
+
+ # Handle Gemini's multi-part content
+ if isinstance(raw_content, list):
+ answer = ""
+ for block in raw_content:
+ if isinstance(block, dict) and block.get('type') == 'text':
+ answer += block.get('text', '')
+ elif isinstance(block, str):
+ answer += block
+ answer = answer.strip() or str(raw_content)
+ else:
+ answer = raw_content
+
+ # Update history
+ self.chat_history.append(HumanMessage(content=question))
+ self.chat_history.append(AIMessage(content=answer))
+                if len(self.chat_history) > 20:
+                    self.chat_history = self.chat_history[-20:]
+
+ return answer, []
+
+ except Exception as e:
+ # Fallback for Groq/LLM Tool Errors & Rate Limits
+ error_str = str(e)
+ if any(err in error_str for err in ["tool_use_failed", "invalid_request_error", "400", "429", "RESOURCE_EXHAUSTED"]):
+ logger.warning(f"Agent failed ({error_str}), falling back to Linear RAG.")
+ return self._linear_chat(question)
+ raise e
+
+ # 2. Linear RAG Mode (Fallback)
+ return self._linear_chat(question)
+
+ except Exception as e:
+ logger.error(f"Error during chat: {e}", exc_info=True)
+ return f"Error: {str(e)}", []
+
+ def _linear_chat(self, question: str) -> Tuple[str, List[dict]]:
+        """Legacy Linear RAG implementation with history-aware retrieval.
+
+        Returns:
+            Tuple of (answer, sources) where sources is a list of dicts with file_path and url
+        """
+ try:
+ # Contextualize query based on history
+ contextualized_query = self._contextualize_query(question, self.chat_history)
+
+ # Retrieve relevant documents
+ docs = self.retriever.invoke(contextualized_query)
+ logger.info(f"Retrieved {len(docs)} documents")
+
+ if not docs:
+ return "I don't have any information about this codebase. Please make sure the codebase has been indexed properly.", []
+
+ # Build context from documents
+ context_text = "\n\n".join([
+ f"File: {doc.metadata.get('file_path', 'unknown')}\n{doc.page_content[:500]}..."
+ for doc in docs[:5] # Limit to top 5 docs
+ ])
+
+ # Extract sources
+ sources = []
+ for doc in docs[:5]:
+ file_path = doc.metadata.get("file_path") or doc.metadata.get("source", "unknown")
+ sources.append({
+ "file_path": file_path,
+ "url": doc.metadata.get("url", f"file://{file_path}"),
+ })
+
+ # Build prompt with history
+ qa_system_prompt = (
+ f"You are a Code Chatbot, an expert software engineering assistant helping me quickly understand "
+ f"a codebase called {self.repo_name}.\n"
+ "Assume I am an advanced developer and answer my questions in the most succinct way possible.\n"
+ "Always provide code examples where relevant.\n"
+ "Link your answers to specific files if possible.\n\n"
+ "Here are some snippets from the codebase:\n\n"
+ f"{context_text}"
+ )
+
+ # Build messages with history
+ messages = [SystemMessage(content=qa_system_prompt)]
+
+ # Add chat history
+ for msg in self.chat_history[-10:]: # Last 10 messages for context
+ messages.append(msg)
+
+ # Add current question
+ messages.append(HumanMessage(content=question))
+
+ # Get response from LLM
+ response_msg = self.llm.invoke(messages)
+ answer = response_msg.content
+
+ # Update chat history
+ self.chat_history.append(HumanMessage(content=question))
+ self.chat_history.append(AIMessage(content=answer))
+
+ # Keep history manageable (last 20 messages)
+ if len(self.chat_history) > 20:
+ self.chat_history = self.chat_history[-20:]
+
+ return answer, sources
+
+ except Exception as e:
+ logger.error(f"Error during chat: {e}", exc_info=True)
+ return f"Error: {str(e)}", []
+
+ def clear_memory(self):
+ """Clear the conversation history."""
+ self.chat_history.clear()
diff --git a/code_chatbot/rate_limiter.py b/code_chatbot/rate_limiter.py
new file mode 100644
index 0000000000000000000000000000000000000000..dae3e9e7d9fd73a914f45d0eb1c9c36db33a66a9
--- /dev/null
+++ b/code_chatbot/rate_limiter.py
@@ -0,0 +1,170 @@
+"""
+Smart Rate Limiter with Adaptive Delays and Caching
+Helps maximize chat usage within free tier limits
+"""
+
+import time
+import logging
+from typing import Optional, Dict, Any
+from datetime import datetime, timedelta
+from functools import lru_cache
+import hashlib
+
+logger = logging.getLogger(__name__)
+
+class RateLimiter:
+ """
+ Adaptive rate limiter that:
+ 1. Tracks API usage per provider
+ 2. Implements smart delays
+ 3. Caches responses for repeated queries
+ 4. Provides usage statistics
+ """
+
+ def __init__(self, provider: str = "gemini"):
+ self.provider = provider
+ self.request_times = []
+ self.token_usage = {"input": 0, "output": 0, "total": 0}
+ self.last_request_time = None
+
+ # Load configuration (with fallbacks if config file missing)
+ try:
+ import rate_limit_config as config
+ except ImportError:
+ # Use defaults if config not found
+ class config:
+ GEMINI_RPM = 15
+ GEMINI_MIN_DELAY = 2.0
+ GEMINI_BURST_DELAY = 8.0
+ GROQ_RPM = 30
+ GROQ_MIN_DELAY = 1.0
+ GROQ_BURST_DELAY = 10.0
+ ENABLE_CACHE = True
+ CACHE_TTL = 300
+
+ # Provider-specific limits
+ self.limits = {
+ "gemini": {
+ "rpm": config.GEMINI_RPM,
+ "min_delay": config.GEMINI_MIN_DELAY,
+ "burst_delay": config.GEMINI_BURST_DELAY,
+ },
+ "groq": {
+ "rpm": config.GROQ_RPM,
+ "min_delay": config.GROQ_MIN_DELAY,
+ "burst_delay": config.GROQ_BURST_DELAY,
+ }
+ }
+
+ self.response_cache = {} if config.ENABLE_CACHE else None
+ self.cache_ttl = config.CACHE_TTL
+
+ def get_cache_key(self, query: str, context_hash: str = "") -> str:
+ """Generate cache key for a query"""
+ combined = f"{query}:{context_hash}"
+ return hashlib.md5(combined.encode()).hexdigest()
+
+ def get_cached_response(self, cache_key: str) -> Optional[Dict[str, Any]]:
+ """Check if we have a cached response"""
+ if self.response_cache is None:
+ return None
+ if cache_key in self.response_cache:
+ cached_data, timestamp = self.response_cache[cache_key]
+ if time.time() - timestamp < self.cache_ttl:
+                logger.info("🎯 Cache hit! Saved an API call.")
+ return cached_data
+ else:
+ # Expired, remove it
+ del self.response_cache[cache_key]
+ return None
+
+ def cache_response(self, cache_key: str, response: Dict[str, Any]):
+ """Cache a response"""
+ if self.response_cache is None:
+ return
+ self.response_cache[cache_key] = (response, time.time())
+ # Keep cache size manageable
+ if len(self.response_cache) > 100:
+ # Remove oldest entries
+ sorted_items = sorted(self.response_cache.items(), key=lambda x: x[1][1])
+ for key, _ in sorted_items[:20]: # Remove 20 oldest
+ del self.response_cache[key]
+
+ def calculate_smart_delay(self) -> float:
+ """
+ Calculate optimal delay based on recent usage.
+ Returns delay in seconds.
+ """
+ config = self.limits.get(self.provider, self.limits["gemini"])
+
+ # Clean old request times (older than 1 minute)
+ cutoff = time.time() - 60
+ self.request_times = [t for t in self.request_times if t > cutoff]
+
+ # Check if we're approaching the rate limit
+ requests_last_minute = len(self.request_times)
+
+ if requests_last_minute >= config["rpm"] * 0.9: # 90% of limit
+ logger.warning(f"⚠️ Approaching rate limit ({requests_last_minute}/{config['rpm']} RPM)")
+ return config["burst_delay"]
+ elif requests_last_minute >= config["rpm"] * 0.7: # 70% of limit
+ return config["min_delay"] * 1.5
+ else:
+ return config["min_delay"]
+
+ def wait_if_needed(self):
+ """
+ Smart wait that adapts to usage patterns.
+ Only waits when necessary to avoid rate limits.
+ """
+ if self.last_request_time is None:
+ self.last_request_time = time.time()
+ self.request_times.append(time.time())
+ return
+
+ delay = self.calculate_smart_delay()
+ elapsed = time.time() - self.last_request_time
+
+ if elapsed < delay:
+ wait_time = delay - elapsed
+ logger.info(f"⏱️ Smart delay: waiting {wait_time:.1f}s to avoid rate limit...")
+ time.sleep(wait_time)
+
+ self.last_request_time = time.time()
+ self.request_times.append(time.time())
+
+ def record_usage(self, input_tokens: int = 0, output_tokens: int = 0):
+ """Track token usage for statistics"""
+ self.token_usage["input"] += input_tokens
+ self.token_usage["output"] += output_tokens
+ self.token_usage["total"] += (input_tokens + output_tokens)
+
+ def get_usage_stats(self) -> Dict[str, Any]:
+ """Get current usage statistics"""
+ cutoff = time.time() - 60
+ recent_requests = len([t for t in self.request_times if t > cutoff])
+
+ return {
+ "provider": self.provider,
+ "requests_last_minute": recent_requests,
+ "total_tokens": self.token_usage["total"],
+ "input_tokens": self.token_usage["input"],
+ "output_tokens": self.token_usage["output"],
+ "cache_size": len(self.response_cache) if self.response_cache else 0
+ }
+
+ def reset_stats(self):
+ """Reset usage statistics"""
+ self.token_usage = {"input": 0, "output": 0, "total": 0}
+ self.request_times = []
+ logger.info("📊 Usage statistics reset")
+
+
+# Global rate limiters (one per provider)
+_rate_limiters: Dict[str, RateLimiter] = {}
+
+def get_rate_limiter(provider: str) -> RateLimiter:
+ """Get or create rate limiter for a provider"""
+ if provider not in _rate_limiters:
+ _rate_limiters[provider] = RateLimiter(provider)
+ return _rate_limiters[provider]
diff --git a/code_chatbot/reranker.py b/code_chatbot/reranker.py
new file mode 100644
index 0000000000000000000000000000000000000000..f48d397282f5f436ad50f8a708d3cddd7e513237
--- /dev/null
+++ b/code_chatbot/reranker.py
@@ -0,0 +1,39 @@
+import logging
+from typing import List
+from langchain_core.documents import Document
+from sentence_transformers import CrossEncoder
+
+logger = logging.getLogger(__name__)
+
+class Reranker:
+ """
+ Uses a Cross-Encoder to re-rank documents retrieved by the vector store.
+ This significantly improves precision by scoring the query against each document directly.
+ """
+ def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
+ logger.info(f"Loading Reranker model: {model_name}")
+ self.model = CrossEncoder(model_name)
+
+ def rerank(self, query: str, documents: List[Document], top_k: int = 5) -> List[Document]:
+ if not documents:
+ return []
+
+ # Prepare pairs for scoring: [[query, doc_text], ...]
+ pairs = [[query, doc.page_content] for doc in documents]
+
+ # Predict scores
+ scores = self.model.predict(pairs)
+
+ # Attach scores to docs and sort
+ scored_docs = []
+ for i, doc in enumerate(documents):
+ # We can store the score in metadata if needed
+ doc.metadata["rerank_score"] = float(scores[i])
+ scored_docs.append((doc, scores[i]))
+
+ # Sort by score descending
+ scored_docs.sort(key=lambda x: x[1], reverse=True)
+
+ # Return top_k
+ top_docs = [doc for doc, score in scored_docs[:top_k]]
+ return top_docs
diff --git a/code_chatbot/retriever_wrapper.py b/code_chatbot/retriever_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..8248e10510328ce7e5078e35715879ded83e952d
--- /dev/null
+++ b/code_chatbot/retriever_wrapper.py
@@ -0,0 +1,96 @@
+"""Wrapper retriever that adds reranking and multi-query support."""
+
+import logging
+from typing import List, Optional, Any
+from langchain_core.retrievers import BaseRetriever
+from langchain_core.documents import Document
+from langchain_core.callbacks import CallbackManagerForRetrieverRun
+from code_chatbot.reranker import Reranker
+
+# Try to import MultiQueryRetriever - may not be available in all versions
+try:
+ from langchain.retrievers.multi_query import MultiQueryRetriever
+except ImportError:
+ try:
+ from langchain_community.retrievers import MultiQueryRetriever
+ except ImportError:
+ MultiQueryRetriever = None # type: ignore
+
+logger = logging.getLogger(__name__)
+
+
+class RerankingRetriever(BaseRetriever):
+ """Wraps a base retriever and applies reranking to results."""
+
+ base_retriever: BaseRetriever
+ reranker: Any
+ top_k: int = 5
+
+ class Config:
+ arbitrary_types_allowed = True
+
+ def __init__(self, base_retriever: BaseRetriever, reranker: Reranker, top_k: int = 5):
+ super().__init__(base_retriever=base_retriever, reranker=reranker, top_k=top_k)
+
+ def _get_relevant_documents(
+ self, query: str, *, run_manager: CallbackManagerForRetrieverRun
+ ) -> List[Document]:
+ """Retrieve documents and rerank them."""
+ # Get documents from base retriever
+ docs = self.base_retriever.invoke(query)
+ logger.info(f"Base retriever returned {len(docs)} documents")
+
+ if not docs:
+ return []
+
+ # Rerank
+ reranked_docs = self.reranker.rerank(query, docs, top_k=self.top_k)
+ logger.info(f"Reranked to {len(reranked_docs)} top documents")
+
+ return reranked_docs
+
+
+def build_enhanced_retriever(
+ base_retriever: BaseRetriever,
+ llm=None,
+ use_multi_query: bool = False,
+ use_reranking: bool = True,
+ rerank_top_k: int = 5,
+) -> BaseRetriever:
+ """
+ Builds an enhanced retriever with optional multi-query expansion and reranking.
+
+ Args:
+ base_retriever: The base retriever (e.g., from vector store)
+ llm: LLM for multi-query expansion (required if use_multi_query=True)
+ use_multi_query: Whether to use multi-query retriever for query expansion
+ use_reranking: Whether to apply reranking
+ rerank_top_k: Number of top documents to return after reranking
+ """
+ retriever = base_retriever
+
+ # Apply multi-query expansion if requested
+ if use_multi_query:
+ if MultiQueryRetriever is None:
+ logger.warning("MultiQueryRetriever not available, skipping multi-query expansion")
+ elif not llm:
+ logger.warning("Multi-query retriever requires an LLM, skipping multi-query expansion")
+ else:
+ retriever = MultiQueryRetriever.from_llm(
+ retriever=retriever,
+ llm=llm
+ )
+ logger.info("Applied multi-query retriever for query expansion")
+
+ # Apply reranking if requested
+ if use_reranking:
+ reranker = Reranker()
+ retriever = RerankingRetriever(
+ base_retriever=retriever,
+ reranker=reranker,
+ top_k=rerank_top_k
+ )
+ logger.info("Applied reranking to retriever")
+
+ return retriever
+
diff --git a/code_chatbot/tools.py b/code_chatbot/tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bb8e8ebba5a0dd85967dca9f1696f5170767eed
--- /dev/null
+++ b/code_chatbot/tools.py
@@ -0,0 +1,183 @@
+import os
+import glob
+from typing import List, Optional
+from langchain_core.tools import tool
+from pydantic import BaseModel, Field
+
+# Define Input Schemas
+class ListFilesInput(BaseModel):
+ path: str = Field(description="Directory path to list files from. Use '.' for root.")
+
+class ReadFileInput(BaseModel):
+ file_path: str = Field(description="Path to the file to read.")
+
+# Define Tools Factory
+def get_filesystem_tools(root_dir: str = "."):
+ """Returns a list of tools bound to the specified root directory."""
+
+ # Ensure root_dir is absolute
+ root_dir = os.path.abspath(root_dir)
+
+ @tool("list_files", args_schema=ListFilesInput)
+ def list_files(path: str = ".") -> str:
+ """Lists files in the specified directory."""
+ try:
+ # Resolve target path relative to root_dir
+ if path == ".":
+ target_path = root_dir
+ else:
+ target_path = os.path.abspath(os.path.join(root_dir, path))
+
+ # Security check: ensure we are inside the codebase
+            if os.path.commonpath([root_dir, target_path]) != root_dir:
+                return f"Error: Access denied. Path must be within the codebase: {root_dir}"
+
+ if not os.path.exists(target_path):
+ return f"Error: Path does not exist: {path}"
+
+ files = []
+ for item in os.listdir(target_path):
+ if item.startswith(".") and item != ".gitignore": continue
+
+ full_item_path = os.path.join(target_path, item)
+
+ if os.path.isdir(full_item_path):
+ files.append(f"{item}/")
+ else:
+ files.append(item)
+
+ # Sort for stability
+ files.sort()
+ return "\n".join(files)
+ except Exception as e:
+ return f"Error listing files: {e}"
+
+ @tool("read_file", args_schema=ReadFileInput)
+ def read_file(file_path: str) -> str:
+ """Reads the content of a file."""
+ try:
+ # Resolve full path
+ full_path = os.path.abspath(os.path.join(root_dir, file_path))
+
+ # Security check
+            if os.path.commonpath([root_dir, full_path]) != root_dir:
+                return "Error: Access denied. File must be within the codebase."
+
+ if not os.path.exists(full_path):
+ return f"Error: File not found: {file_path}"
+
+ # Check file size to avoid overloading context
+ # Groq TPM limit is ~12k tokens. 12000 chars is roughly 3k tokens.
+ # We strictly prevent reading massive files to keep the agent alive.
+ if os.path.getsize(full_path) > 12000:
+ return f"Error: File '{file_path}' is too large ({os.path.getsize(full_path)} bytes). Read specific lines or functions instead."
+
+ with open(full_path, "r", errors='ignore') as f:
+ content = f.read()
+ return content
+ except Exception as e:
+ return f"Error reading file: {e}"
+
+ return [list_files, read_file]
+
+
+# ============================================================================
+# Call Graph Tools
+# ============================================================================
+
+class FindCallersInput(BaseModel):
+ function_name: str = Field(description="Name of the function to find callers for")
+
+class FindCalleesInput(BaseModel):
+ function_name: str = Field(description="Name of the function to find callees for")
+
+class FindCallChainInput(BaseModel):
+ start_function: str = Field(description="Name of the starting function")
+ end_function: str = Field(description="Name of the target function to trace to")
+
+
+def get_call_graph_tools(analyzer):
+ """Returns tools for querying the call graph."""
+
+ @tool("find_callers", args_schema=FindCallersInput)
+ def find_callers(function_name: str) -> str:
+ """Find all functions that call the specified function.
+ Useful for understanding: "Who uses this function?" or "What depends on this?"
+ """
+ if analyzer is None:
+ return "Error: No code analysis available. Index a codebase first."
+
+ try:
+ callers = analyzer.get_callers(function_name)
+
+ if not callers:
+ return f"No callers found for '{function_name}'. It may be unused or called dynamically."
+
+ result = f"Functions that call '{function_name}':\n"
+ for caller in callers:
+ parts = caller.split("::")
+ if len(parts) == 2:
+ result += f" - {parts[1]} (in {parts[0]})\n"
+ else:
+ result += f" - {caller}\n"
+
+ return result
+ except Exception as e:
+ return f"Error finding callers: {e}"
+
+ @tool("find_callees", args_schema=FindCalleesInput)
+ def find_callees(function_name: str) -> str:
+ """Find all functions that are called by the specified function.
+ Useful for understanding: "What does this function do?" or "What are its dependencies?"
+ """
+ if analyzer is None:
+ return "Error: No code analysis available. Index a codebase first."
+
+ try:
+ callees = analyzer.get_callees(function_name)
+
+ if not callees:
+ return f"No callees found for '{function_name}'. It may not call any other tracked functions."
+
+ result = f"Functions called by '{function_name}':\n"
+ for callee in callees:
+ parts = callee.split("::")
+ if len(parts) == 2:
+ result += f" - {parts[1]} (in {parts[0]})\n"
+ else:
+ result += f" - {callee}\n"
+
+ return result
+ except Exception as e:
+ return f"Error finding callees: {e}"
+
+ @tool("find_call_chain", args_schema=FindCallChainInput)
+ def find_call_chain(start_function: str, end_function: str) -> str:
+ """Find the call path from one function to another.
+ Useful for: "How does execution flow from main() to save_to_db()?"
+ """
+ if analyzer is None:
+ return "Error: No code analysis available. Index a codebase first."
+
+ try:
+ chains = analyzer.get_call_chain(start_function, end_function)
+
+ if not chains:
+ return f"No call path found from '{start_function}' to '{end_function}'."
+
+ result = f"Call paths from '{start_function}' to '{end_function}':\n\n"
+ for i, chain in enumerate(chains[:5], 1):
+ result += f"Path {i}:\n"
+ for j, node in enumerate(chain):
+ parts = node.split("::")
+ func_name = parts[1] if len(parts) == 2 else node
+ indent = " " * j
+ arrow = "-> " if j > 0 else ""
+ result += f"{indent}{arrow}{func_name}\n"
+ result += "\n"
+
+ return result
+ except Exception as e:
+ return f"Error finding call chain: {e}"
+
+ return [find_callers, find_callees, find_call_chain]
diff --git a/code_chatbot/universal_ingestor.py b/code_chatbot/universal_ingestor.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c2f8b3bd4d0e6721453e6e867e1643fa623e1fd
--- /dev/null
+++ b/code_chatbot/universal_ingestor.py
@@ -0,0 +1,376 @@
+"""Universal ingestor that handles multiple input types: ZIP files, GitHub URLs, local directories, etc."""
+
+import logging
+import os
+import zipfile
+import requests
+import tempfile
+import shutil
+from abc import ABC, abstractmethod
+from typing import Any, Dict, Generator, Tuple, Optional
+from urllib.parse import urlparse
+from pathlib import Path
+
+from langchain_core.documents import Document
+
+logger = logging.getLogger(__name__)
+
+
+class DataManager(ABC):
+ """Abstract base class for data managers."""
+
+ def __init__(self, dataset_id: str):
+ self.dataset_id = dataset_id
+
+ @abstractmethod
+ def download(self) -> bool:
+ """Downloads/prepares the data."""
+ pass
+
+ @abstractmethod
+ def walk(self, get_content: bool = True) -> Generator[Tuple[Any, Dict], None, None]:
+ """Yields (content, metadata) tuples for each file."""
+ pass
+
+
+class UniversalIngestor(DataManager):
+ """Factory class to ingest data from various sources."""
+
+ def __init__(self, source: str, local_dir: Optional[str] = None, **kwargs):
+ """
+ Args:
+ source: Can be:
+ - GitHub URL (e.g., "https://github.com/owner/repo")
+ - GitHub repo ID (e.g., "owner/repo")
+ - Local directory path
+ - ZIP file path
+ - Web URL
+ local_dir: Directory to store/clone/download data
+ **kwargs: Additional arguments for specific managers
+ """
+ super().__init__(dataset_id=source)
+ self.source = source
+ self.kwargs = kwargs
+ self.local_dir = local_dir or os.path.join(tempfile.gettempdir(), "code_chatbot")
+ self.delegate = self._detect_handler()
+
+ def _detect_handler(self) -> DataManager:
+ """Detects the type of input and returns the appropriate handler."""
+ source = self.source.strip()
+
+ # Check if it's a URL
+ if self._is_url(source):
+            if "github.com" in source or source.count("/") == 1:
+ # GitHub URL or repo ID (owner/repo)
+ if "github.com" in source:
+ # Extract repo_id from URL
+ parts = urlparse(source).path.strip("/").split("/")
+ if len(parts) >= 2:
+ repo_id = f"{parts[0]}/{parts[1]}"
+ else:
+ raise ValueError(f"Invalid GitHub URL: {source}")
+ else:
+ # Assume it's owner/repo format
+ repo_id = source
+
+ return GitHubRepoManager(
+ repo_id=repo_id,
+ local_dir=self.local_dir,
+ **self.kwargs
+ )
+
+ # Other web URLs
+ return WebDocManager(source, local_dir=self.local_dir)
+
+ # Check if it's a ZIP file
+ if source.lower().endswith('.zip') and os.path.isfile(source):
+ return ZIPFileManager(source, local_dir=self.local_dir)
+
+ # Check if it's a local directory
+ if os.path.isdir(source):
+ return LocalDirectoryManager(source)
+
+ # Check if it's a local file
+ if os.path.isfile(source):
+ return LocalFileManager(source)
+
+ raise ValueError(f"Unable to determine source type for: {source}")
+
+    def _is_url(self, s: str) -> bool:
+        """Checks if a string is a URL or a GitHub "owner/repo" ID."""
+        result = urlparse(s)
+        if result.scheme and result.netloc:
+            return True
+        # urlparse does not raise on plain strings, so the "owner/repo"
+        # GitHub shorthand must be checked explicitly, not in an except block.
+        if "/" in s and s.count("/") == 1 and not os.path.exists(s):
+            return True
+        return False
+
+ @property
+ def local_path(self) -> str:
+ """Returns the local path where data is stored."""
+ if hasattr(self.delegate, "local_path"):
+ return self.delegate.local_path
+ if hasattr(self.delegate, "path"):
+ return self.delegate.path
+ return self.local_dir
+
+ def download(self) -> bool:
+ """Downloads/prepares the data."""
+ return self.delegate.download()
+
+ def walk(self, get_content: bool = True) -> Generator[Tuple[Any, Dict], None, None]:
+ """Yields (content, metadata) tuples."""
+ yield from self.delegate.walk(get_content)
+
+
+class ZIPFileManager(DataManager):
+ """Handles ZIP file ingestion."""
+
+ def __init__(self, zip_path: str, local_dir: str):
+ super().__init__(dataset_id=zip_path)
+ self.zip_path = zip_path
+ self.local_dir = local_dir
+ self.path = os.path.join(local_dir, "extracted", os.path.basename(zip_path).replace('.zip', ''))
+
+ def download(self) -> bool:
+ """Extracts the ZIP file."""
+ if os.path.exists(self.path):
+ logger.info(f"ZIP already extracted to {self.path}")
+ return True
+
+ os.makedirs(self.path, exist_ok=True)
+
+ try:
+ with zipfile.ZipFile(self.zip_path, 'r') as zip_ref:
+ zip_ref.extractall(self.path)
+ logger.info(f"Extracted {self.zip_path} to {self.path}")
+ return True
+ except Exception as e:
+ logger.error(f"Failed to extract ZIP: {e}")
+ return False
+
+ def walk(self, get_content: bool = True) -> Generator[Tuple[Any, Dict], None, None]:
+ """Walks extracted files."""
+ if not os.path.exists(self.path):
+ return
+
+ IGNORE_DIRS = {'__pycache__', '.git', 'node_modules', 'venv', '.venv', '.env'}
+ IGNORE_EXTENSIONS = {
+ '.pyc', '.png', '.jpg', '.jpeg', '.gif', '.ico', '.svg', '.mp4', '.mov',
+ '.zip', '.tar', '.gz', '.pdf', '.exe', '.bin', '.pkl', '.npy', '.pt', '.pth'
+ }
+
+ for root, dirs, files in os.walk(self.path):
+ dirs[:] = [d for d in dirs if d not in IGNORE_DIRS and not d.startswith('.')]
+
+ for file in files:
+ if file.startswith('.'):
+ continue
+
+ file_path = os.path.join(root, file)
+ _, ext = os.path.splitext(file)
+ if ext.lower() in IGNORE_EXTENSIONS:
+ continue
+
+ rel_path = os.path.relpath(file_path, self.path)
+
+ if get_content:
+ try:
+ with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
+ content = f.read()
+ yield content, {
+ "file_path": file_path,
+ "source": rel_path,
+ "file_name": file
+ }
+ except Exception as e:
+ logger.warning(f"Failed to read {file_path}: {e}")
+ else:
+ yield {"file_path": file_path, "source": rel_path, "file_name": file}
+
+
+class LocalDirectoryManager(DataManager):
+ """Handles local directory ingestion."""
+
+ def __init__(self, path: str):
+ super().__init__(dataset_id=path)
+ self.path = path
+ self.local_dir = path
+
+ def download(self) -> bool:
+ return os.path.isdir(self.path)
+
+ def walk(self, get_content: bool = True) -> Generator[Tuple[Any, Dict], None, None]:
+ """Walks local directory."""
+ IGNORE_DIRS = {'__pycache__', '.git', 'node_modules', 'venv', '.venv', '.env'}
+ IGNORE_EXTENSIONS = {
+ '.pyc', '.png', '.jpg', '.jpeg', '.gif', '.ico', '.svg', '.mp4', '.mov',
+ '.zip', '.tar', '.gz', '.pdf', '.exe', '.bin', '.pkl', '.npy', '.pt', '.pth'
+ }
+
+ for root, dirs, files in os.walk(self.path):
+ dirs[:] = [d for d in dirs if d not in IGNORE_DIRS and not d.startswith('.')]
+
+ for file in files:
+ if file.startswith('.'):
+ continue
+
+ file_path = os.path.join(root, file)
+ _, ext = os.path.splitext(file)
+ if ext.lower() in IGNORE_EXTENSIONS:
+ continue
+
+ rel_path = os.path.relpath(file_path, self.path)
+
+ if get_content:
+ try:
+ with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
+ content = f.read()
+ yield content, {
+ "file_path": file_path,
+ "source": rel_path,
+ "url": f"file://{file_path}"
+ }
+ except Exception as e:
+ logger.warning(f"Skipping {file_path}: {e}")
+ else:
+ yield {"file_path": file_path, "source": rel_path}
+
+
+class LocalFileManager(DataManager):
+ """Handles single file ingestion."""
+
+ def __init__(self, path: str):
+ super().__init__(dataset_id=path)
+ self.path = path
+
+ def download(self) -> bool:
+ return os.path.exists(self.path)
+
+ def walk(self, get_content: bool = True) -> Generator[Tuple[Any, Dict], None, None]:
+ """Yields the single file."""
+ if get_content:
+ try:
+ with open(self.path, 'r', encoding='utf-8', errors='ignore') as f:
+ content = f.read()
+ yield content, {"file_path": self.path, "source": os.path.basename(self.path)}
+ except Exception as e:
+ logger.error(f"Failed to read {self.path}: {e}")
+ else:
+ yield {"file_path": self.path, "source": os.path.basename(self.path)}
+
+
+class GitHubRepoManager(DataManager):
+ """Handles GitHub repository cloning and ingestion."""
+
+ def __init__(self, repo_id: str, local_dir: str, access_token: Optional[str] = None, commit_hash: Optional[str] = None):
+ """
+ Args:
+ repo_id: GitHub repo in format "owner/repo"
+ local_dir: Directory to clone to
+ access_token: GitHub token for private repos
+ commit_hash: Optional commit hash to checkout
+ """
+ super().__init__(dataset_id=repo_id)
+ self.repo_id = repo_id
+ self.local_dir = local_dir
+ self.access_token = access_token or os.getenv("GITHUB_TOKEN")
+ self.commit_hash = commit_hash
+ self.path = os.path.join(local_dir, repo_id.replace("/", "_"))
+
+ def download(self) -> bool:
+ """Clones the GitHub repository."""
+ if os.path.exists(self.path) and os.listdir(self.path):
+ logger.info(f"Repo already cloned at {self.path}")
+ return True
+
+ try:
+            from git import Repo
+
+ if self.access_token:
+ clone_url = f"https://{self.access_token}@github.com/{self.repo_id}.git"
+ else:
+ clone_url = f"https://github.com/{self.repo_id}.git"
+
+ os.makedirs(self.local_dir, exist_ok=True)
+
+ if self.commit_hash:
+ repo = Repo.clone_from(clone_url, self.path)
+ repo.git.checkout(self.commit_hash)
+ else:
+ Repo.clone_from(clone_url, self.path, depth=1, single_branch=True)
+
+ logger.info(f"Cloned {self.repo_id} to {self.path}")
+ return True
+ except ImportError:
+ logger.error("GitPython not installed. Install with: pip install gitpython")
+ raise
+ except Exception as e:
+ logger.error(f"Failed to clone {self.repo_id}: {e}")
+ return False
+
+ def walk(self, get_content: bool = True) -> Generator[Tuple[Any, Dict], None, None]:
+ """Walks cloned repository."""
+ if not os.path.exists(self.path):
+ return
+
+ # Use LocalDirectoryManager logic
+ manager = LocalDirectoryManager(self.path)
+ yield from manager.walk(get_content)
+
+
+class WebDocManager(DataManager):
+ """Handles web page/document ingestion."""
+
+ def __init__(self, url: str, local_dir: str):
+ super().__init__(dataset_id=url)
+ self.url = url
+ self.local_dir = local_dir
+
+ def download(self) -> bool:
+ """Checks if URL is accessible."""
+ try:
+ response = requests.get(self.url, timeout=10)
+ return response.status_code == 200
+ except Exception as e:
+ logger.error(f"Could not reach {self.url}: {e}")
+ return False
+
+ def walk(self, get_content: bool = True) -> Generator[Tuple[Any, Dict], None, None]:
+ """Fetches web page content."""
+ try:
+ response = requests.get(self.url, timeout=10)
+ if get_content:
+ from bs4 import BeautifulSoup
+ soup = BeautifulSoup(response.content, 'html.parser')
+ text = soup.get_text(separator='\n')
+ yield text, {"file_path": self.url, "url": self.url, "source": "web"}
+ else:
+ yield {"file_path": self.url, "url": self.url, "source": "web"}
+ except Exception as e:
+ logger.error(f"Failed to fetch {self.url}: {e}")
+
+
+def process_source(source: str, extract_to: str) -> Tuple[list, str]:
+ """
+ Convenience function to process any source type and return documents + local path.
+
+ Returns:
+ Tuple of (documents, local_path)
+ """
+ ingestor = UniversalIngestor(source, local_dir=extract_to)
+
+ if not ingestor.download():
+ raise ValueError(f"Failed to download/prepare source: {source}")
+
+ documents = []
+ for content, metadata in ingestor.walk(get_content=True):
+ documents.append(Document(
+ page_content=content,
+ metadata=metadata
+ ))
+
+ return documents, ingestor.local_path
+
diff --git a/rate_limit_config.py b/rate_limit_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..db59f9a98cb9c22da69bffbe6888fdb3e4aaeed9
--- /dev/null
+++ b/rate_limit_config.py
@@ -0,0 +1,63 @@
+# Rate Limit Configuration
+# Customize these settings to control API usage and maximize chat availability
+
+# ============================================================================
+# PROVIDER LIMITS (Free Tier Defaults)
+# ============================================================================
+
+# Gemini 2.0 Flash Experimental (Latest Model)
+GEMINI_RPM = 15 # Requests per minute
+GEMINI_TPM = 1000000 # Tokens per minute (1 million)
+GEMINI_MIN_DELAY = 4.0 # Minimum seconds between requests (60s / 15 RPM = 4s)
+GEMINI_BURST_DELAY = 10.0 # Delay when approaching limit
+
+# Groq Free Tier (Increased delays to prevent rate limits)
+GROQ_RPM = 30 # Requests per minute
+GROQ_TPM = 20000  # Tokens per minute (conservative estimate)
+GROQ_MIN_DELAY = 8.0 # Minimum 8 seconds between requests (was 1s)
+GROQ_BURST_DELAY = 20.0 # Delay when approaching limit (was 10s)
+
+# ============================================================================
+# OPTIMIZATION SETTINGS
+# ============================================================================
+
+# Response Caching
+ENABLE_CACHE = True # Cache identical queries to save API calls
+CACHE_TTL = 300 # Cache lifetime in seconds (5 minutes)
+MAX_CACHE_SIZE = 100 # Maximum number of cached responses
+
+# Adaptive Delays
+USE_ADAPTIVE_DELAYS = True # Dynamically adjust delays based on usage
+RATE_LIMIT_THRESHOLD = 0.7 # Trigger longer delays at 70% of limit (0.0-1.0)
+
+# Context Optimization
+MAX_AGENT_TOOL_RESULTS = 5 # Number of search results per tool call
+MAX_AGENT_CONTENT_LENGTH = 2000 # Characters per search result
+MAX_LINEAR_DOCS = 8 # Number of documents for linear RAG
+MAX_LINEAR_CONTENT_LENGTH = 1500 # Characters per document
+
+# ============================================================================
+# ADVANCED SETTINGS
+# ============================================================================
+
+# Fallback Behavior
+AUTO_FALLBACK_TO_LINEAR = True # Fall back to linear RAG on agent rate limits
+MAX_AGENT_RETRIES = 2 # Number of retries on rate limit errors
+
+# Statistics & Monitoring
+SHOW_USAGE_STATS = True # Display usage stats in sidebar
+LOG_RATE_LIMIT_WARNINGS = True # Log when approaching limits
+
+# Token Budget (Optional - set to 0 to disable)
+# Stop making requests after hitting daily token budget
+DAILY_TOKEN_BUDGET_GEMINI = 0 # 0 = unlimited (within API limits)
+DAILY_TOKEN_BUDGET_GROQ = 0 # 0 = unlimited (within API limits)
+
+# ============================================================================
+# TIPS FOR MAXIMIZING USAGE
+# ============================================================================
+# 1. Set lower MIN_DELAY values for faster responses (but higher risk)
+# 2. Enable CACHE to avoid repeat API calls
+# 3. Reduce MAX_AGENT_TOOL_RESULTS if hitting rate limits frequently
+# 4. Use linear RAG mode for simpler questions (faster, fewer API calls)
+# 5. Switch providers if one is exhausted (Gemini <-> Groq)
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d35b8b493d1f7fb32bde5c9750e8601a284c4fcb
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,20 @@
+langchain
+langchain-community
+langchain-core
+streamlit
+chromadb
+openai
+pydantic
+tiktoken
+watchdog
+langchain-google-genai
+python-dotenv
+langchain-groq
+tree-sitter
+tree-sitter-python
+tree-sitter-javascript
+networkx
+sentence-transformers
+gitpython
+beautifulsoup4
+pygments
diff --git a/sage/chat.py b/sage/chat.py
deleted file mode 100644
index 3ae10a13305bcd635f1387abb1f08a2b3dfc7618..0000000000000000000000000000000000000000
--- a/sage/chat.py
+++ /dev/null
@@ -1,128 +0,0 @@
-"""A gradio app that enables users to chat with their codebase.
-
-You must run `sage-index $GITHUB_REPO` first in order to index the codebase into a vector store.
-"""
-
-import logging
-
-import configargparse
-import gradio as gr
-from dotenv import load_dotenv
-from langchain.chains import create_history_aware_retriever, create_retrieval_chain
-from langchain.chains.combine_documents import create_stuff_documents_chain
-from langchain.schema import AIMessage, HumanMessage
-from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
-
-import sage.config as sage_config
-from sage.llm import build_llm_via_langchain
-from sage.retriever import build_retriever_from_args
-
-load_dotenv()
-
-
-def build_rag_chain(args):
- """Builds a RAG chain via LangChain."""
- llm = build_llm_via_langchain(args.llm_provider, args.llm_model)
- retriever = build_retriever_from_args(args)
-
- # Prompt to contextualize the latest query based on the chat history.
- contextualize_q_system_prompt = (
- "Given a chat history and the latest user question which might reference context in the chat history, "
- "formulate a standalone question which can be understood without the chat history. Do NOT answer the question, "
- "just reformulate it if needed and otherwise return it as is."
- )
- contextualize_q_prompt = ChatPromptTemplate.from_messages(
- [
- ("system", contextualize_q_system_prompt),
- MessagesPlaceholder("chat_history"),
- ("human", "{input}"),
- ]
- )
- contextualize_q_llm = llm.with_config(tags=["contextualize_q_llm"])
- history_aware_retriever = create_history_aware_retriever(contextualize_q_llm, retriever, contextualize_q_prompt)
-
- qa_system_prompt = (
- f"You are my coding buddy, helping me quickly understand a GitHub repository called {args.repo_id}."
- "Assume I am an advanced developer and answer my questions in the most succinct way possible."
- "\n\n"
- "Here are some snippets from the codebase."
- "\n\n"
- "{context}"
- )
- qa_prompt = ChatPromptTemplate.from_messages(
- [
- ("system", qa_system_prompt),
- MessagesPlaceholder("chat_history"),
- ("human", "{input}"),
- ]
- )
-
- question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
- rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
- return rag_chain
-
-
-def main():
- parser = configargparse.ArgParser(
- description="Batch-embeds a GitHub repository and its issues.", ignore_unknown_config_file_keys=True
- )
- parser.add(
- "--share",
- default=False,
- help="Whether to make the gradio app publicly accessible.",
- )
-
- validator = sage_config.add_all_args(parser)
- args = parser.parse_args()
- validator(args)
-
- rag_chain = build_rag_chain(args)
-
- def source_md(file_path: str, url: str) -> str:
- """Formats a context source in Markdown."""
- return f"[{file_path}]({url})"
-
- async def _predict(message, history):
- """Performs one RAG operation."""
- history_langchain_format = []
- for human, ai in history:
- history_langchain_format.append(HumanMessage(content=human))
- history_langchain_format.append(AIMessage(content=ai))
- history_langchain_format.append(HumanMessage(content=message))
-
- query_rewrite = ""
- response = ""
- async for event in rag_chain.astream_events(
- {
- "input": message,
- "chat_history": history_langchain_format,
- },
- version="v1",
- ):
- if event["name"] == "retrieve_documents" and "output" in event["data"]:
- sources = [(doc.metadata["file_path"], doc.metadata["url"]) for doc in event["data"]["output"]]
- # Deduplicate while preserving the order.
- sources = list(dict.fromkeys(sources))
- response += "## Sources:\n" + "\n".join([source_md(s[0], s[1]) for s in sources]) + "\n## Response:\n"
-
- elif event["event"] == "on_chat_model_stream":
- chunk = event["data"]["chunk"].content
-
- if "contextualize_q_llm" in event["tags"]:
- query_rewrite += chunk
- else:
- # This is the actual response to the user query.
- if not response:
- logging.info(f"Query rewrite: {query_rewrite}")
- response += chunk
- yield response
-
- gr.ChatInterface(
- _predict,
- title=args.repo_id,
- examples=["What does this repo do?", "Give me some sample code."],
- ).launch(share=args.share)
-
-
-if __name__ == "__main__":
- main()
diff --git a/sage/chunker.py b/sage/chunker.py
deleted file mode 100644
index 45931080b545331f9208cc3510216a00d1e67bad..0000000000000000000000000000000000000000
--- a/sage/chunker.py
+++ /dev/null
@@ -1,311 +0,0 @@
-"""Chunker abstraction and implementations."""
-
-import logging
-import os
-from abc import ABC, abstractmethod
-from dataclasses import dataclass
-from functools import cached_property
-from typing import Any, Dict, List, Optional
-
-import nbformat
-import pygments
-import tiktoken
-from semchunk import chunk as chunk_via_semchunk
-from tree_sitter import Node
-from tree_sitter_language_pack import get_parser
-
-from sage.constants import TEXT_FIELD
-
-logger = logging.getLogger(__name__)
-tokenizer = tiktoken.get_encoding("cl100k_base")
-
-
-class Chunk:
- @abstractmethod
- def content(self) -> str:
- """The content of the chunk to be indexed."""
-
- @abstractmethod
- def metadata(self) -> Dict:
- """Metadata for the chunk to be indexed."""
-
-
-@dataclass
-class FileChunk(Chunk):
- """A chunk of code or text extracted from a file in the repository."""
-
- file_content: str # The content of the entire file, not just this chunk.
- file_metadata: Dict # Metadata of the entire file, not just this chunk.
- start_byte: int
- end_byte: int
-
- @cached_property
- def filename(self):
- if not "file_path" in self.file_metadata:
- raise ValueError("file_metadata must contain a 'file_path' key.")
- return self.file_metadata["file_path"]
-
- @cached_property
- def content(self) -> Optional[str]:
- """The text content to be embedded. Might contain information beyond just the text snippet from the file."""
- return self.filename + "\n\n" + self.file_content[self.start_byte : self.end_byte]
-
- @cached_property
- def metadata(self):
- """Converts the chunk to a dictionary that can be passed to a vector store."""
- # Some vector stores require the IDs to be ASCII.
- filename_ascii = self.filename.encode("ascii", "ignore").decode("ascii")
- chunk_metadata = {
- # Some vector stores require the IDs to be ASCII.
- "id": f"{filename_ascii}_{self.start_byte}_{self.end_byte}",
- "start_byte": self.start_byte,
- "end_byte": self.end_byte,
- "length": self.end_byte - self.start_byte,
- # Note to developer: When choosing a large chunk size, you might exceed the vector store's metadata
- # size limit. In that case, you can simply store the start/end bytes above, and fetch the content
- # directly from the repository when needed.
- TEXT_FIELD: self.content,
- }
- chunk_metadata.update(self.file_metadata)
- return chunk_metadata
-
- @cached_property
- def num_tokens(self):
- """Number of tokens in this chunk."""
- return len(tokenizer.encode(self.content, disallowed_special=()))
-
- def __eq__(self, other):
- if isinstance(other, Chunk):
- return (
- self.filename == other.filename
- and self.start_byte == other.start_byte
- and self.end_byte == other.end_byte
- )
- return False
-
- def __hash__(self):
- return hash((self.filename, self.start_byte, self.end_byte))
-
-
-class Chunker(ABC):
- """Abstract class for chunking a datum into smaller pieces."""
-
- @abstractmethod
- def chunk(self, content: Any, metadata: Dict) -> List[Chunk]:
- """Chunks a datum into smaller pieces."""
-
-
-class CodeFileChunker(Chunker):
- """Splits a code file into chunks of at most `max_tokens` tokens each."""
-
- def __init__(self, max_tokens: int):
- self.max_tokens = max_tokens
- self.text_chunker = TextFileChunker(max_tokens)
-
- @staticmethod
- def _get_language_from_filename(filename: str):
- """Returns a canonical name for the language of the file, based on its extension.
- Returns None if the language is unknown to the pygments lexer.
- """
- # pygments doesn't recognize .tsx files and returns None. So we need to special-case them.
- extension = os.path.splitext(filename)[1]
- if extension == ".tsx":
- return "tsx"
-
- try:
- lexer = pygments.lexers.get_lexer_for_filename(filename)
- return lexer.name.lower()
- except pygments.util.ClassNotFound:
- return None
-
- def _chunk_node(self, node: Node, file_content: str, file_metadata: Dict) -> List[FileChunk]:
- """Splits a node in the parse tree into a flat list of chunks."""
- node_chunk = FileChunk(file_content, file_metadata, node.start_byte, node.end_byte)
-
- if node_chunk.num_tokens <= self.max_tokens:
- return [node_chunk]
-
- if not node.children:
- # This is a leaf node, but it's too long. We'll have to split it with a text tokenizer.
- return self.text_chunker.chunk(file_content[node.start_byte : node.end_byte], file_metadata)
-
- chunks = []
- for child in node.children:
- chunks.extend(self._chunk_node(child, file_content, file_metadata))
-
- for chunk in chunks:
- # This should always be true. Otherwise there must be a bug in the code.
- assert chunk.num_tokens <= self.max_tokens
-
- # Merge neighboring chunks if their combined size doesn't exceed max_tokens. The goal is to avoid pathologically
- # small chunks that end up being undeservedly preferred by the retriever.
- merged_chunks = []
- for chunk in chunks:
- if not merged_chunks:
- merged_chunks.append(chunk)
- elif merged_chunks[-1].num_tokens + chunk.num_tokens < self.max_tokens - 50:
- # There's a good chance that merging these two chunks will be under the token limit. We're not 100% sure
- # at this point, because tokenization is not necessarily additive.
- merged = FileChunk(
- file_content,
- file_metadata,
- merged_chunks[-1].start_byte,
- chunk.end_byte,
- )
- if merged.num_tokens <= self.max_tokens:
- merged_chunks[-1] = merged
- else:
- merged_chunks.append(chunk)
- else:
- merged_chunks.append(chunk)
- chunks = merged_chunks
-
- for chunk in merged_chunks:
- # This should always be true. Otherwise there's a bug worth investigating.
- assert chunk.num_tokens <= self.max_tokens
-
- return merged_chunks
-
- @staticmethod
- def is_code_file(filename: str) -> bool:
- """Checks whether pygment & tree_sitter can parse the file as code."""
- language = CodeFileChunker._get_language_from_filename(filename)
- return language and language not in ["text only", "None"]
-
- @staticmethod
- def parse_tree(filename: str, content: str) -> List[str]:
- """Parses the code in a file and returns the parse tree."""
- language = CodeFileChunker._get_language_from_filename(filename)
-
- if not language or language in ["text only", "None"]:
- logging.debug("%s doesn't seem to be a code file.", filename)
- return None
-
- try:
- parser = get_parser(language)
- except LookupError:
- logging.debug("%s doesn't seem to be a code file.", filename)
- return None
- # This should never happen unless there's a bug in the code, but we'd rather not crash.
- except Exception as e:
- logging.warn("Failed to get parser for %s: %s", filename, e)
- return None
-
- tree = parser.parse(bytes(content, "utf8"))
-
- if not tree.root_node.children or tree.root_node.children[0].type == "ERROR":
- logging.warning("Failed to parse code in %s.", filename)
- return None
- return tree
-
- def chunk(self, content: Any, metadata: Dict) -> List[Chunk]:
- """Chunks a code file into smaller pieces."""
- file_content = content
- file_metadata = metadata
- file_path = metadata["file_path"]
-
- if not file_content.strip():
- return []
-
- tree = self.parse_tree(file_path, file_content)
- if tree is None:
- return []
-
- file_chunks = self._chunk_node(tree.root_node, file_content, file_metadata)
- for chunk in file_chunks:
- # Make sure that the chunk has content and doesn't exceed the max_tokens limit. Otherwise there must be
- # a bug in the code.
- assert (
- chunk.num_tokens <= self.max_tokens
- ), f"Chunk size {chunk.num_tokens} exceeds max_tokens {self.max_tokens}."
-
- return file_chunks
-
-
-class TextFileChunker(Chunker):
- """Wrapper around semchunk: https://github.com/umarbutler/semchunk."""
-
- def __init__(self, max_tokens: int):
- self.max_tokens = max_tokens
- self.count_tokens = lambda text: len(tokenizer.encode(text, disallowed_special=()))
-
- def chunk(self, content: Any, metadata: Dict) -> List[Chunk]:
- """Chunks a text file into smaller pieces."""
- file_content = content
- file_metadata = metadata
- file_path = file_metadata["file_path"]
-
- # We need to allocate some tokens for the filename, which is part of the chunk content.
- extra_tokens = self.count_tokens(file_path + "\n\n")
- text_chunks = chunk_via_semchunk(file_content, self.max_tokens - extra_tokens, self.count_tokens)
-
- file_chunks = []
- start = 0
- for text_chunk in text_chunks:
- # This assertion should always be true. Otherwise there's a bug worth finding.
- assert self.count_tokens(text_chunk) <= self.max_tokens - extra_tokens
-
- # Find the start/end positions of the chunks.
- start = file_content.index(text_chunk, start)
- if start == -1:
- logging.warning("Couldn't find semchunk in content: %s", text_chunk)
- else:
- end = start + len(text_chunk)
- file_chunks.append(FileChunk(file_content, file_metadata, start, end))
-
- start = end
-
- return file_chunks
-
-
-class IpynbFileChunker(Chunker):
- """Extracts the python code from a Jupyter notebook, removing all the boilerplate.
-
- Based on https://github.com/GoogleCloudPlatform/generative-ai/blob/main/language/code/code_retrieval_augmented_generation.ipynb
- """
-
- def __init__(self, code_chunker: CodeFileChunker):
- self.code_chunker = code_chunker
-
- def chunk(self, content: Any, metadata: Dict) -> List[Chunk]:
- filename = metadata["file_path"]
-
- if not filename.lower().endswith(".ipynb"):
- logging.warn("IPYNBChunker is only for .ipynb files.")
- return []
-
- notebook = nbformat.reads(content, as_version=nbformat.NO_CONVERT)
- python_code = "\n".join([cell.source for cell in notebook.cells if cell.cell_type == "code"])
-
- tmp_metadata = {"file_path": filename.replace(".ipynb", ".py")}
- chunks = self.code_chunker.chunk(python_code, tmp_metadata)
-
- for chunk in chunks:
- # Update filenames back to .ipynb
- chunk.metadata["file_path"] = filename
- return chunks
-
-
-class UniversalFileChunker(Chunker):
- """Chunks a file into smaller pieces, regardless of whether it's code or text."""
-
- def __init__(self, max_tokens: int):
- self.max_tokens = max_tokens
- self.code_chunker = CodeFileChunker(max_tokens)
- self.ipynb_chunker = IpynbFileChunker(self.code_chunker)
- self.text_chunker = TextFileChunker(max_tokens)
-
- def chunk(self, content: Any, metadata: Dict) -> List[Chunk]:
- if not "file_path" in metadata:
- raise ValueError("metadata must contain a 'file_path' key.")
- file_path = metadata["file_path"]
-
- # Figure out the appropriate chunker to use.
- if file_path.lower().endswith(".ipynb"):
- chunker = self.ipynb_chunker
- elif CodeFileChunker.is_code_file(file_path):
- chunker = self.code_chunker
- else:
- chunker = self.text_chunker
-
- return chunker.chunk(content, metadata)
diff --git a/sage/code_symbols.py b/sage/code_symbols.py
deleted file mode 100644
index 19b7b1caef6e20e03c35ba4f212796e3c07887e8..0000000000000000000000000000000000000000
--- a/sage/code_symbols.py
+++ /dev/null
@@ -1,49 +0,0 @@
-"""Utilities to extract code symbols (class and method names) from code files."""
-
-import logging
-from typing import List, Tuple
-
-from tree_sitter import Node
-
-from sage.chunker import CodeFileChunker
-
-
-def _extract_classes_and_methods(node: Node, acc: List[Tuple[str, str]], parent_class: str = None):
- """Extracts classes and methods from a tree-sitter node and places them in the `acc` accumulator."""
- if node.type in ["class_definition", "class_declaration"]:
- class_name_node = node.child_by_field_name("name")
- if class_name_node:
- class_name = class_name_node.text.decode("utf-8")
- acc.append((class_name, None))
- for child in node.children:
- _extract_classes_and_methods(child, acc, class_name)
- elif node.type in ["function_definition", "method_definition"]:
- function_name_node = node.child_by_field_name("name")
- if function_name_node:
- acc.append((parent_class, function_name_node.text.decode("utf-8")))
- # We're not going deeper into a method. This means we're missing nested functions.
- else:
- for child in node.children:
- _extract_classes_and_methods(child, acc, parent_class)
-
-
-def get_code_symbols(file_path: str, content: str) -> List[Tuple[str, str]]:
- """Extracts code symbols from a file.
-
- Code symbols are tuples of the form (class_name, method_name). For classes, method_name is None. For methods
- that do not belong to a class, class_name is None.
- """
- if not CodeFileChunker.is_code_file(file_path):
- return []
-
- if not content:
- return []
-
- logging.info(f"Extracting code symbols from {file_path}")
- tree = CodeFileChunker.parse_tree(file_path, content)
- if not tree:
- return []
-
- classes_and_methods = []
- _extract_classes_and_methods(tree.root_node, classes_and_methods)
- return classes_and_methods
diff --git a/sage/config.py b/sage/config.py
deleted file mode 100644
index a0dc4eb4d8a29a954cd5d107330edb7ed21197e9..0000000000000000000000000000000000000000
--- a/sage/config.py
+++ /dev/null
@@ -1,427 +0,0 @@
-"""Utility methods to define and validate flags."""
-
-import argparse
-import importlib.resources as resources
-import logging
-import os
-import re
-from typing import Callable
-
-from configargparse import ArgumentParser
-
-from sage.reranker import RerankerProvider
-
-# Limits defined here: https://ai.google.dev/gemini-api/docs/models/gemini
-GEMINI_MAX_TOKENS_PER_CHUNK = 2048
-
-MARQO_MAX_CHUNKS_PER_BATCH = 64
-# The ADA embedder from OpenAI has a maximum of 8192 tokens.
-OPENAI_MAX_TOKENS_PER_CHUNK = 8192
-# The OpenAI batch embedding API enforces a maximum of 2048 chunks per batch.
-OPENAI_MAX_CHUNKS_PER_BATCH = 2048
-# The OpenAI batch embedding API enforces a maximum of 3M tokens processed at once.
-OPENAI_MAX_TOKENS_PER_JOB = 3_000_000
-
-# Note that OpenAI embedding models have fixed dimensions, however, taking a slice of them is possible.
-# See "Reducing embedding dimensions" under https://platform.openai.com/docs/guides/embeddings/use-cases and
-# https://platform.openai.com/docs/api-reference/embeddings/create#embeddings-create-dimensions
-OPENAI_DEFAULT_EMBEDDING_SIZE = {
- "text-embedding-ada-002": 1536,
- "text-embedding-3-small": 1536,
- "text-embedding-3-large": 3072,
-}
-
-VOYAGE_MAX_CHUNKS_PER_BATCH = 128
-
-
-def get_voyage_max_tokens_per_batch(model: str) -> int:
- """Returns the maximum number of tokens per batch for the Voyage model.
- See https://docs.voyageai.com/reference/embeddings-api."""
- if model == "voyage-3-lite":
- return 1_000_000
- if model in ["voyage-3", "voyage-2"]:
- return 320_000
- return 120_000
-
-
-def get_voyage_embedding_size(model: str) -> int:
- """Returns the embedding size for the Voyage model. See https://docs.voyageai.com/docs/embeddings#model-choices."""
- if model == "voyage-3-lite":
- return 512
- if model == "voyage-2-code":
- return 1536
- return 1024
-
-
-def add_config_args(parser: ArgumentParser):
- """Adds configuration-related arguments to the parser."""
- parser.add(
- "--mode",
- choices=["local", "remote"],
- default="remote",
- help="Whether to use local-only resources or call third-party providers (remote).",
- )
- parser.add(
- "--config",
- is_config_file=True,
- help="Path to .yaml configuration file.",
- )
- args, _ = parser.parse_known_args()
- config_file = resources.files("sage").joinpath(f"configs/{args.mode}.yaml")
- parser.set_defaults(config=str(config_file))
- return lambda _: True
-
-
-def add_repo_args(parser: ArgumentParser) -> Callable:
- """Adds repository-related arguments to the parser and returns a validator."""
- parser.add("repo_id", help="The ID of the repository to index")
- parser.add("--commit-hash", help="Optional commit hash to checkout. When not provided, defaults to HEAD.")
- parser.add(
- "--local-dir",
- default="repos",
- help="The local directory to store the repository",
- )
- return validate_repo_args
-
-
-def add_embedding_args(parser: ArgumentParser) -> Callable:
- """Adds embedding-related arguments to the parser and returns a validator."""
- parser.add("--embedding-provider", default="marqo", choices=["openai", "voyage", "marqo", "gemini"])
- parser.add(
- "--embedding-model",
- type=str,
- default=None,
- help="The embedding model. Defaults to `text-embedding-ada-002` for OpenAI and `hf/e5-base-v2` for Marqo.",
- )
- parser.add(
- "--embedding-size",
- type=int,
- default=None,
- help="The embedding size to use for OpenAI text-embedding-3* models. Defaults to 1536 for small and 3072 for "
- "large. Note that no other OpenAI models support a dynamic embedding size, nor do models used with Marqo.",
- )
- parser.add(
- "--tokens-per-chunk",
- type=int,
- default=800,
- help="https://arxiv.org/pdf/2406.14497 recommends a value between 200-800.",
- )
- parser.add(
- "--chunks-per-batch",
- type=int,
- help="Maximum chunks per batch. We recommend 2000 for the OpenAI embedder. Marqo enforces a limit of 64.",
- )
- parser.add(
- "--max-embedding-jobs",
- type=int,
- help="Maximum number of embedding jobs to run. Specifying this might result in "
- "indexing only part of the repository, but prevents you from burning through OpenAI credits.",
- )
- return validate_embedding_args
-
-
-def add_vector_store_args(parser: ArgumentParser) -> Callable:
- """Adds vector store-related arguments to the parser and returns a validator."""
- parser.add(
- "--vector-store-provider", default="marqo", choices=["pinecone", "marqo", "chroma", "faiss", "milvus", "qdrant"]
- )
- parser.add("--index-name", default="sage", help="Index name for the vector store index.")
- parser.add(
- "--milvus-uri",
- default="milvus_sage.db",
- help="URI for milvus. We default it to milvus_sage.db",
- )
- parser.add(
- "--index-namespace",
- default=None,
- help="Index namespace for this repo. When not specified, we default it to a derivative of the repo name.",
- )
- parser.add(
- "--marqo-url",
- default="http://localhost:8882",
- help="URL for the Marqo server. Required if using Marqo as embedder or vector store.",
- )
- parser.add(
- "--retrieval-alpha",
- default=1.0,
- type=float,
- help="Takes effect for Pinecone retriever only. The weight of the dense (embeddings-based) vs sparse (BM25) "
- "encoder in the final retrieval score. A value of 0.0 means BM25 only, 1.0 means embeddings only.",
- )
- parser.add(
- "--retriever-top-k", default=25, type=int, help="The number of top documents to retrieve from the vector store."
- )
- parser.add(
- "--multi-query-retriever",
- action=argparse.BooleanOptionalAction,
- default=False,
- help="When set to True, we rewrite the query 5 times, perform retrieval for each rewrite, and take the union "
- "of retrieved documents. See https://python.langchain.com/v0.1/docs/modules/data_connection/retrievers/MultiQueryRetriever/.",
- )
- parser.add(
- "--llm-retriever",
- action=argparse.BooleanOptionalAction,
- default=True,
- help="When set to True, we use an LLM for retrieval: we pass the repository file hierarchy together with the "
- "user query and ask the LLM to choose relevant files solely based on their paths. No indexing will be done, so "
- "all the vector store / embedding arguments will be ignored.",
- )
- return validate_vector_store_args
-
-
-def add_indexing_args(parser: ArgumentParser) -> Callable:
- """Adds indexing-related arguments to the parser and returns a validator."""
- parser.add(
- "--include",
- help="Path to a file containing a list of extensions to include. One extension per line.",
- )
- parser.add(
- "--exclude",
- help="Path to a file containing a list of extensions to exclude. One extension per line.",
- )
- # Pass --no-index-repo in order to not index the repository.
- parser.add(
- "--index-repo",
- action=argparse.BooleanOptionalAction,
- default=True,
- help="Whether to index the repository. At least one of --index-repo and --index-issues must be True.",
- )
- # Pass --no-index-issues in order to not index the issues.
- parser.add(
- "--index-issues",
- action=argparse.BooleanOptionalAction,
- default=False,
- help="Whether to index GitHub issues. At least one of --index-repo and --index-issues must be True. When "
- "--index-issues is set, you must also set a GITHUB_TOKEN environment variable.",
- )
- # Pass --no-index-issue-comments in order to not index the comments of GitHub issues.
- parser.add(
- "--index-issue-comments",
- action=argparse.BooleanOptionalAction,
- default=False,
- help="Whether to index the comments of GitHub issues. This is only relevant if --index-issues is set. "
- "GitHub's API for downloading comments is quite slow. Indexing solely the body of an issue seems to bring most "
- "of the gains anyway.",
- )
- return validate_indexing_args
-
-
-def add_reranking_args(parser: ArgumentParser) -> Callable:
- """Adds reranking-related arguments to the parser."""
- parser.add("--reranker-provider", default="huggingface", choices=[r.value for r in RerankerProvider])
- parser.add(
- "--reranker-model",
- help="The reranker model name. When --reranker-provider=huggingface, we suggest choosing a model from the "
- "SentenceTransformers Cross-Encoders library https://huggingface.co/cross-encoder?sort_models=downloads#models",
- )
- parser.add("--reranker-top-k", default=5, help="The number of top documents to return after reranking.")
- # Trivial validator (nothing to check).
- return lambda _: True
-
-
-def add_llm_args(parser: ArgumentParser) -> Callable:
- """Adds language model-related arguments to the parser."""
- parser.add("--llm-provider", default="ollama", choices=["openai", "anthropic", "ollama"])
- parser.add(
- "--llm-model",
- help="The LLM name. Must be supported by the provider specified via --llm-provider.",
- )
- # Trivial validator (nothing to check).
- return lambda _: True
-
-
-def add_all_args(parser: ArgumentParser) -> Callable:
- """Adds all arguments to the parser and returns a validator."""
- arg_validators = [
- add_config_args(parser),
- add_repo_args(parser),
- add_embedding_args(parser),
- add_vector_store_args(parser),
- add_reranking_args(parser),
- add_indexing_args(parser),
- add_llm_args(parser),
- ]
-
- def validate_all(args):
- for validator in arg_validators:
- validator(args)
-
- return validate_all
-
-
-def validate_repo_args(args):
- """Validates the configuration of the repository."""
- if not re.match(r"^[^/]+/[^/]+$", args.repo_id):
- raise ValueError("repo_id must be in the format 'owner/repo'")
-
-
-def _validate_openai_embedding_args(args):
- """Validates the configuration of the OpenAI batch embedder and sets defaults."""
- if args.embedding_provider == "openai" and not os.getenv("OPENAI_API_KEY"):
- raise ValueError("Please set the OPENAI_API_KEY environment variable.")
-
- if not args.embedding_model:
- args.embedding_model = "text-embedding-3-small"
-
- if args.embedding_model not in OPENAI_DEFAULT_EMBEDDING_SIZE:
- raise ValueError(f"Unrecognized embeddings.model={args.embedding_model}")
-
- if not args.embedding_size:
- args.embedding_size = OPENAI_DEFAULT_EMBEDDING_SIZE.get(args.embedding_model)
-
- if not args.tokens_per_chunk:
- # https://arxiv.org/pdf/2406.14497 recommends a value between 200-800.
- args.tokens_per_chunk = 800
- elif args.tokens_per_chunk > OPENAI_MAX_TOKENS_PER_CHUNK:
- args.tokens_per_chunk = OPENAI_MAX_TOKENS_PER_CHUNK
- logging.warning(
- f"OpenAI enforces a limit of {OPENAI_MAX_TOKENS_PER_CHUNK} tokens per chunk. "
- "Overwriting embeddings.tokens_per_chunk."
- )
-
- if not args.chunks_per_batch:
- args.chunks_per_batch = OPENAI_MAX_CHUNKS_PER_BATCH
- elif args.chunks_per_batch > OPENAI_MAX_CHUNKS_PER_BATCH:
- args.chunks_per_batch = OPENAI_MAX_CHUNKS_PER_BATCH
- logging.warning(
- f"OpenAI enforces a limit of {OPENAI_MAX_CHUNKS_PER_BATCH} chunks per batch. "
- "Overwriting embeddings.chunks_per_batch."
- )
-
- tokens_per_job = args.tokens_per_chunk * args.chunks_per_batch
- if tokens_per_job >= OPENAI_MAX_TOKENS_PER_JOB:
- raise ValueError(f"The maximum number of tokens per job is {OPENAI_MAX_TOKENS_PER_JOB}. Got {tokens_per_job}")
-
-
-def _validate_voyage_embedding_args(args):
- """Validates the configuration of the Voyage batch embedder and sets defaults."""
- if args.embedding_provider == "voyage" and not os.getenv("VOYAGE_API_KEY"):
- raise ValueError("Please set the VOYAGE_API_KEY environment variable.")
-
- if not args.embedding_model:
- args.embedding_model = "voyage-code-2"
-
- if not args.tokens_per_chunk:
- # https://arxiv.org/pdf/2406.14497 recommends a value between 200-800.
- args.tokens_per_chunk = 800
-
- if not args.chunks_per_batch:
- args.chunks_per_batch = VOYAGE_MAX_CHUNKS_PER_BATCH
- elif args.chunks_per_batch > VOYAGE_MAX_CHUNKS_PER_BATCH:
- args.chunks_per_batch = VOYAGE_MAX_CHUNKS_PER_BATCH
- logging.warning(f"Voyage enforces a limit of {VOYAGE_MAX_CHUNKS_PER_BATCH} chunks per batch. Overwriting.")
-
- max_tokens = get_voyage_max_tokens_per_batch(args.embedding_model)
- if args.tokens_per_chunk * args.chunks_per_batch > max_tokens:
- raise ValueError(
- f"Voyage enforces a limit of {max_tokens} tokens per batch. "
- "Reduce either --tokens-per-chunk or --chunks-per-batch."
- )
-
- if not args.embedding_size:
- args.embedding_size = get_voyage_embedding_size(args.embedding_model)
-
-
-def _validate_marqo_embedding_args(args):
- """Validates the configuration of the Marqo batch embedder and sets defaults."""
- if not args.embedding_model:
- args.embedding_model = "hf/e5-base-v2"
-
- if not args.chunks_per_batch:
- args.chunks_per_batch = MARQO_MAX_CHUNKS_PER_BATCH
- elif args.chunks_per_batch > MARQO_MAX_CHUNKS_PER_BATCH:
- args.chunks_per_batch = MARQO_MAX_CHUNKS_PER_BATCH
- logging.warning(
- f"Marqo enforces a limit of {MARQO_MAX_CHUNKS_PER_BATCH} chunks per batch. "
- "Overwriting embeddings.chunks_per_batch."
- )
-
-
-def _validate_gemini_embedding_args(args):
- """Validates the configuration of the Gemini batch embedder and sets defaults."""
- if not args.embedding_model:
- args.embedding_model = "models/text-embedding-004"
- if not os.getenv("GOOGLE_API_KEY"):
- raise ValueError("Please set the GOOGLE_API_KEY environment variable if using `gemini` embeddings.")
- if not args.chunks_per_batch:
- # This value is reasonable but arbitrary (i.e. Gemini does not explicitly enforce a limit).
- args.chunks_per_batch = 2000
-
- if not args.tokens_per_chunk:
- args.tokens_per_chunk = GEMINI_MAX_TOKENS_PER_CHUNK
- if not args.embedding_size:
- args.embedding_size = 768
-
-
-def validate_embedding_args(args):
- """Validates the configuration of the batch embedder and sets defaults."""
- if args.llm_retriever:
- # When using an LLM to retrieve, we are not running the embedder.
- return True
- if args.embedding_provider == "openai":
- _validate_openai_embedding_args(args)
- elif args.embedding_provider == "voyage":
- _validate_voyage_embedding_args(args)
- elif args.embedding_provider == "marqo":
- _validate_marqo_embedding_args(args)
- elif args.embedding_provider == "gemini":
- _validate_gemini_embedding_args(args)
- else:
- raise ValueError(f"Unrecognized --embedding-provider={args.embedding_provider}")
-
-
-def validate_vector_store_args(args):
- """Validates the configuration of the vector store and sets defaults."""
- if args.llm_retriever:
- if not os.getenv("ANTHROPIC_API_KEY"):
- raise ValueError(
- "Please set the ANTHROPIC_API_KEY environment variable to use the LLM retriever. "
- "(We're constrained to Claude because we need prompt caching.)"
- )
-
- if args.index_issues:
- # The LLM retriever only makes sense on the code repository, since it passes file paths to the LLM.
- raise ValueError("Cannot use --index-issues with --llm-retriever.")
-
- # When using an LLM retriever, all the vector store arguments are ignored.
- return
-
- if not args.index_namespace:
- # Attempt to derive a default index namespace from the repository information.
- if "repo_id" not in args:
- raise ValueError("Please set a value for --index-namespace.")
- args.index_namespace = args.repo_id
- if "commit_hash" in args and args.commit_hash:
- args.index_namespace += "/" + args.commit_hash
- if args.vector_store_provider == "marqo":
- # Marqo namespaces must match this pattern: [a-zA-Z_-][a-zA-Z0-9_-]*
- args.index_namespace = re.sub(r"[^a-zA-Z0-9_-]", "_", args.index_namespace)
-
- if args.vector_store_provider == "marqo":
- if not args.marqo_url:
- args.marqo_url = "http://localhost:8882"
- if "/" in args.index_namespace:
- raise ValueError(f"Marqo doesn't allow slashes in --index-namespace={args.index_namespace}.")
-
- elif args.vector_store_provider == "pinecone":
- if not os.getenv("PINECONE_API_KEY"):
- raise ValueError("Please set the PINECONE_API_KEY environment variable.")
- if not args.index_name:
- raise ValueError(f"Please set the vector_store.index_name value.")
-
-
-def validate_indexing_args(args):
- """Validates the indexing configuration and sets defaults."""
- if args.include and args.exclude:
- raise ValueError("At most one of indexing.include and indexing.exclude can be specified.")
- if not args.include and not args.exclude:
- args.exclude = str(resources.files("sage").joinpath("sample-exclude.txt"))
- if args.include and not os.path.exists(args.include):
- raise ValueError(f"Path --include={args.include} does not exist.")
- if args.exclude and not os.path.exists(args.exclude):
- raise ValueError(f"Path --exclude={args.exclude} does not exist.")
- if not args.index_repo and not args.index_issues:
- raise ValueError("Either --index_repo or --index_issues must be set to true.")
- if args.index_issues and not os.getenv("GITHUB_TOKEN"):
- raise ValueError("Please set the GITHUB_TOKEN environment variable.")
diff --git a/sage/configs/local.yaml b/sage/configs/local.yaml
deleted file mode 100644
index 350fb20dccd7e977ea07de18e8a0b34b86b3e82c..0000000000000000000000000000000000000000
--- a/sage/configs/local.yaml
+++ /dev/null
@@ -1,16 +0,0 @@
-# Embeddings
-embedding-provider: marqo
-embedding-model: hf/e5-base-v2
-tokens-per-chunk: 800
-chunks-per-batch: 64
-
-# Vector store
-vector-store-provider: marqo
-
-# LLM
-llm-provider: ollama
-llm-model: llama3.1
-
-# Reranking
-reranker-provider: huggingface
-reranker-model: cross-encoder/ms-marco-MiniLM-L-6-v2
\ No newline at end of file
diff --git a/sage/configs/remote.yaml b/sage/configs/remote.yaml
deleted file mode 100644
index 77dddcfa76193e3b1493de2d710efaf8c0d02515..0000000000000000000000000000000000000000
--- a/sage/configs/remote.yaml
+++ /dev/null
@@ -1,18 +0,0 @@
-llm-retriever: true
-llm-provider: anthropic
-# Here we optimize for ease of setup, so we skip the reranker which would require an extra API key.
-reranker-provider: none
-# Since we skipped the reranker, we can't afford to feed the retriever with too many candidates.
-retriever-top-k: 5
-
-# The settings below (embeddings and vector store) are only relevant when setting --no-llm-retriever
-
-# Embeddings
-embedding-provider: openai
-embedding-model: text-embedding-3-small
-tokens-per-chunk: 800
-chunks-per-batch: 2000
-# Vector store
-vector-store-provider: pinecone
-pinecone-index-name: sage
-hybrid-retrieval: true
diff --git a/sage/constants.py b/sage/constants.py
deleted file mode 100644
index df1f1cd0b07720a54dc95699a3f97bf40aa25f20..0000000000000000000000000000000000000000
--- a/sage/constants.py
+++ /dev/null
@@ -1,3 +0,0 @@
-# This is the key in the metadata that points to the actual text content of a document or chunk.
-# It can mostly be an arbitrary string, but certain classes in LangChain do expect it to be "text" specifically.
-TEXT_FIELD = "text"
diff --git a/sage/data_manager.py b/sage/data_manager.py
deleted file mode 100644
index 680f3fd13a6d80a5f6e7d663467c7e60769903f7..0000000000000000000000000000000000000000
--- a/sage/data_manager.py
+++ /dev/null
@@ -1,256 +0,0 @@
-"""Utility classes to maniuplate GitHub repositories."""
-
-import logging
-import os
-from abc import abstractmethod
-from functools import cached_property
-from typing import Any, Dict, Generator, Tuple
-
-import requests
-from git import GitCommandError, Repo
-
-
-class DataManager:
- def __init__(self, dataset_id: str):
- self.dataset_id = dataset_id
-
- @abstractmethod
- def download(self) -> bool:
- """Downloads the data from a remote location."""
-
- @abstractmethod
- def walk(self) -> Generator[Tuple[Any, Dict], None, None]:
- """Yields a tuple of (data, metadata) for each data item in the dataset."""
-
-
-class GitHubRepoManager(DataManager):
- """Class to manage a local clone of a GitHub repository."""
-
- def __init__(
- self,
- repo_id: str,
- commit_hash: str = None,
- access_token: str = None,
- local_dir: str = None,
- inclusion_file: str = None,
- exclusion_file: str = None,
- ):
- """
- Args:
- repo_id: The identifier of the repository in owner/repo format, e.g. "Storia-AI/sage".
- commit_hash: Optional commit hash to checkout. If not specified, we pull the latest version of the repo.
- access_token: A GitHub access token to use for cloning private repositories. Not needed for public repos.
- local_dir: The local directory where the repository will be cloned.
- inclusion_file: A file with a list of files/directories/extensions to include. Each line must be in one of
- the following formats: "ext:.my-extension", "file:my-file.py", or "dir:my-directory".
- exclusion_file: A file with a list of files/directories/extensions to exclude. Each line must be in one of
- the following formats: "ext:.my-extension", "file:my-file.py", or "dir:my-directory".
- """
- super().__init__(dataset_id=repo_id)
- self.repo_id = repo_id
- self.commit_hash = commit_hash
- self.access_token = access_token
-
- self.local_dir = local_dir or "/tmp/"
- if not os.path.exists(self.local_dir):
- os.makedirs(self.local_dir)
- self.local_path = os.path.join(self.local_dir, repo_id)
-
- self.log_dir = os.path.join(self.local_dir, "logs", repo_id)
- if not os.path.exists(self.log_dir):
- os.makedirs(self.log_dir)
-
- if inclusion_file and exclusion_file:
- raise ValueError("Only one of inclusion_file or exclusion_file should be provided.")
-
- self.inclusions = self._parse_filter_file(inclusion_file) if inclusion_file else None
- self.exclusions = self._parse_filter_file(exclusion_file) if exclusion_file else None
-
- @cached_property
- def is_public(self) -> bool:
- """Checks whether a GitHub repository is publicly visible."""
- response = requests.get(f"https://api.github.com/repos/{self.repo_id}", timeout=10)
- # Note that the response will be 404 for both private and non-existent repos.
- return response.status_code == 200
-
- @cached_property
- def default_branch(self) -> str:
- """Fetches the default branch of the repository from GitHub."""
- headers = {
- "Accept": "application/vnd.github.v3+json",
- }
- if self.access_token:
- headers["Authorization"] = f"token {self.access_token}"
-
- response = requests.get(f"https://api.github.com/repos/{self.repo_id}", headers=headers)
- if response.status_code == 200:
- branch = response.json().get("default_branch", "main")
- else:
- # This happens sometimes when we exceed the Github rate limit. The best bet in this case is to assume the
- # most common naming for the default branch ("main").
- logging.warn(f"Unable to fetch default branch for {self.repo_id}: {response.text}")
- branch = "main"
- return branch
-
- def download(self) -> bool:
- """Clones the repository to the local directory, if it's not already cloned."""
- if os.path.exists(self.local_path):
- # The repository is already cloned.
- return True
-
- if not self.is_public and not self.access_token:
- raise ValueError(f"Repo {self.repo_id} is private or doesn't exist.")
-
- if self.access_token:
- clone_url = f"https://{self.access_token}@github.com/{self.repo_id}.git"
- else:
- clone_url = f"https://github.com/{self.repo_id}.git"
-
- try:
- if self.commit_hash:
- repo = Repo.clone_from(clone_url, self.local_path)
- repo.git.checkout(self.commit_hash)
- else:
- Repo.clone_from(clone_url, self.local_path, depth=1, single_branch=True)
- except GitCommandError as e:
- logging.error("Unable to clone %s from %s. Error: %s", self.repo_id, clone_url, e)
- return False
- return True
-
- def _parse_filter_file(self, file_path: str) -> Dict:
- """Parses a file with files/directories/extensions to include/exclude.
-
- Lines are expected to be in the format:
- # Comment that will be ignored, or
- ext:.my-extension, or
- file:my-file.py, or
- dir:my-directory
- """
- with open(file_path, "r") as f:
- lines = f.readlines()
-
- parsed_data = {"ext": [], "file": [], "dir": []}
- for line in lines:
- if line.startswith("#"):
- # This is a comment line.
- continue
- key, value = line.strip().split(":")
- if key in parsed_data:
- parsed_data[key].append(value)
- else:
- logging.error("Unrecognized key in line: %s. Skipping.", line)
-
- return parsed_data
-
- def _should_include(self, file_path: str) -> bool:
- """Checks whether the file should be indexed."""
- # Exclude symlinks.
- if os.path.islink(file_path):
- return False
-
- # Exclude hidden files and directories.
- if any(part.startswith(".") for part in file_path.split(os.path.sep)):
- return False
-
- if not self.inclusions and not self.exclusions:
- return True
-
- # Filter based on file extensions, file names and directory names.
- _, extension = os.path.splitext(file_path)
- extension = extension.lower()
- file_name = os.path.basename(file_path)
- dirs = os.path.dirname(file_path).split("/")
-
- if self.inclusions:
- return (
- extension in self.inclusions.get("ext", [])
- or file_name in self.inclusions.get("file", [])
- or any(d in dirs for d in self.inclusions.get("dir", []))
- )
- elif self.exclusions:
- return (
- extension not in self.exclusions.get("ext", [])
- and file_name not in self.exclusions.get("file", [])
- and all(d not in dirs for d in self.exclusions.get("dir", []))
- )
- return True
-
- def walk(self, get_content: bool = True) -> Generator[Tuple[Any, Dict], None, None]:
- """Walks the local repository path and yields a tuple of (content, metadata) for each file.
- The filepath is relative to the root of the repository (e.g. "org/repo/your/file/path.py").
-
- Args:
- get_content: When set to True, yields (content, metadata) tuples. When set to False, yields metadata only.
- """
- # We will keep appending to these files during the iteration, so we need to clear them first.
- repo_name = self.repo_id.replace("/", "_")
- included_log_file = os.path.join(self.log_dir, f"included_{repo_name}.txt")
- excluded_log_file = os.path.join(self.log_dir, f"excluded_{repo_name}.txt")
- if os.path.exists(included_log_file):
- os.remove(included_log_file)
- logging.info("Logging included files at %s", included_log_file)
- if os.path.exists(excluded_log_file):
- os.remove(excluded_log_file)
- logging.info("Logging excluded files at %s", excluded_log_file)
-
- for root, _, files in os.walk(self.local_path):
- file_paths = [os.path.join(root, file) for file in files]
- included_file_paths = [f for f in file_paths if self._should_include(f)]
-
- with open(included_log_file, "a") as f:
- for path in included_file_paths:
- f.write(path + "\n")
-
- excluded_file_paths = set(file_paths).difference(set(included_file_paths))
- with open(excluded_log_file, "a") as f:
- for path in excluded_file_paths:
- f.write(path + "\n")
-
- for file_path in included_file_paths:
- # os.path.relpath handles local_dir values with or without a trailing slash.
- relative_file_path = os.path.relpath(file_path, self.local_dir)
- metadata = {
- "file_path": relative_file_path,
- "url": self.url_for_file(relative_file_path),
- }
-
- if not get_content:
- yield metadata
- continue
-
- contents = self.read_file(relative_file_path)
- if contents:
- yield contents, metadata
-
- def url_for_file(self, file_path: str) -> str:
- """Converts a repository file path to a GitHub link."""
- file_path = file_path[len(self.repo_id) + 1 :]
- return f"https://github.com/{self.repo_id}/blob/{self.default_branch}/{file_path}"
-
- def read_file(self, relative_file_path: str) -> str:
- """Reads the contents of a file in the repository."""
- absolute_file_path = os.path.join(self.local_dir, relative_file_path)
- with open(absolute_file_path, "r") as f:
- try:
- contents = f.read()
- return contents
- except UnicodeDecodeError:
- logging.warning("Unable to decode file %s.", absolute_file_path)
- return None
-
- @staticmethod
- def from_args(args: Dict) -> "GitHubRepoManager":
- """Creates a GitHubRepoManager from command-line arguments and clones the underlying repository."""
- repo_manager = GitHubRepoManager(
- repo_id=args.repo_id,
- commit_hash=args.commit_hash,
- access_token=os.getenv("GITHUB_TOKEN"),
- local_dir=args.local_dir,
- inclusion_file=args.include,
- exclusion_file=args.exclude,
- )
- success = repo_manager.download()
- if not success:
- raise ValueError(
- f"Unable to clone {args.repo_id}. Please check that it exists and you have access to it. "
- "For private repositories, please set the GITHUB_TOKEN variable in your environment."
- )
- return repo_manager
diff --git a/sage/embedder.py b/sage/embedder.py
deleted file mode 100644
index ff33d19be577b6ba8875af62834d7f1b9fcfc5dd..0000000000000000000000000000000000000000
--- a/sage/embedder.py
+++ /dev/null
@@ -1,442 +0,0 @@
-"""Batch embedder abstraction and implementations."""
-
-import json
-import logging
-import os
-import time
-from abc import ABC, abstractmethod
-from collections import Counter
-from typing import Dict, Generator, List, Optional, Tuple
-
-import google.generativeai as genai
-import marqo
-import requests
-from openai import OpenAI
-from tenacity import retry, stop_after_attempt, wait_random_exponential
-from tqdm import tqdm
-
-from sage.chunker import Chunk, Chunker
-from sage.constants import TEXT_FIELD
-from sage.data_manager import DataManager
-
-Vector = Tuple[Dict, List[float]] # (metadata, embedding)
-
-
-class BatchEmbedder(ABC):
- """Abstract class for batch embedding of a dataset."""
-
- @abstractmethod
- def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None):
- """Issues batch embedding jobs for the entire dataset."""
-
- @abstractmethod
- def embeddings_are_ready(self) -> bool:
- """Checks whether the batch embedding jobs are done."""
-
- @abstractmethod
- def download_embeddings(self) -> Generator[Vector, None, None]:
- """Yields (chunk_metadata, embedding) pairs for each chunk in the dataset."""
-
-
-class OpenAIBatchEmbedder(BatchEmbedder):
- """Batch embedder that calls OpenAI. See https://platform.openai.com/docs/guides/batch/overview."""
-
- def __init__(
- self, data_manager: DataManager, chunker: Chunker, local_dir: str, embedding_model: str, embedding_size: int
- ):
- self.data_manager = data_manager
- self.chunker = chunker
- self.local_dir = local_dir
- self.embedding_model = embedding_model
- self.embedding_size = embedding_size
- self.client = OpenAI()
-
- def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None) -> str:
- """Issues batch embedding jobs for the entire dataset. Returns the filename containing the job IDs."""
- batch = []
- batch_ids = {} # job_id -> metadata
- chunk_count = 0
- dataset_name = self.data_manager.dataset_id.replace("/", "_")
-
- num_files = sum(1 for _ in self.data_manager.walk(get_content=False))
- pbar = tqdm(total=num_files, desc="Processing files", unit="file")
-
- for content, metadata in self.data_manager.walk():
- chunks = self.chunker.chunk(content, metadata)
- chunk_count += len(chunks)
- batch.extend(chunks)
- pbar.update(1)
-
- if len(batch) > chunks_per_batch:
- for i in range(0, len(batch), chunks_per_batch):
- sub_batch = batch[i : i + chunks_per_batch]
- openai_batch_id = self._issue_job_for_chunks(sub_batch, batch_id=f"{dataset_name}/{len(batch_ids)}")
- batch_ids[openai_batch_id] = [chunk.metadata for chunk in sub_batch]
- if max_embedding_jobs and len(batch_ids) >= max_embedding_jobs:
- logging.info("Reached the maximum number of embedding jobs. Stopping.")
- return
- batch = []
-
- # Finally, commit the last batch.
- if batch:
- openai_batch_id = self._issue_job_for_chunks(batch, batch_id=f"{dataset_name}/{len(batch_ids)}")
- batch_ids[openai_batch_id] = [chunk.metadata for chunk in batch]
-
- logging.info("Issued %d jobs for %d chunks.", len(batch_ids), chunk_count)
-
- timestamp = int(time.time())
- metadata_file = os.path.join(self.local_dir, f"{dataset_name}_openai_batch_ids_{timestamp}.json")
- with open(metadata_file, "w") as f:
- json.dump(batch_ids, f)
- logging.info("Job metadata saved at %s", metadata_file)
- pbar.close()
- return metadata_file
-
- def embeddings_are_ready(self, metadata_file: str) -> bool:
- """Checks whether the embeddings jobs are done (either completed or failed).
-
- Args:
- metadata_file: Path to the file containing the job metadata (output of self.embed_dataset).
- """
- with open(metadata_file, "r") as f:
- batch_ids = json.load(f)
-
- job_ids = batch_ids.keys()
- statuses = [self.client.batches.retrieve(job_id.strip()) for job_id in job_ids]
- are_ready = all(status.status in ["completed", "failed"] for status in statuses)
- status_counts = Counter(status.status for status in statuses)
- logging.info("Job statuses: %s", status_counts)
- return are_ready
-
- def download_embeddings(
- self, metadata_file: str, store_file_chunk_content: bool = True
- ) -> Generator[Vector, None, None]:
- """Yields a (chunk_metadata, embedding) pair for each chunk in the dataset.
-
- Args:
- metadata_file: Path to the file containing the job metadata (output of self.embed_dataset).
- store_file_chunk_content: Whether to store the text content in the metadata for file chunks. Set this to
- False if you want to save space in the vector store. After retrieval, the content of a file chunk can be
- reconstructed based on the file_path, start_byte and end_byte fields in the metadata. This will not
- affect other types of chunks (e.g. GitHub issues) for which the content is harder to reconstruct.
- """
- with open(metadata_file, "r") as f:
- batch_ids = json.load(f)
-
- job_ids = batch_ids.keys()
- statuses = [self.client.batches.retrieve(job_id.strip()) for job_id in job_ids]
-
- for idx, status in enumerate(statuses):
- if status.status == "failed":
- logging.error("Job failed: %s", status)
- continue
-
- if not status.output_file_id:
- error = self.client.files.content(status.error_file_id)
- logging.error("Job %s failed with error: %s", status.id, error.text)
- continue
-
- batch_metadata = batch_ids[status.id]
- file_response = self.client.files.content(status.output_file_id)
- data = json.loads(file_response.text)["response"]["body"]["data"]
- logging.info("Job %s generated %d embeddings.", status.id, len(data))
-
- for datum in data:
- idx = int(datum["index"])
- metadata = batch_metadata[idx]
- if (
- not store_file_chunk_content
- and "file_path" in metadata
- and "start_byte" in metadata
- and "end_byte" in metadata
- ):
- metadata.pop(TEXT_FIELD, None)
- embedding = datum["embedding"]
- yield (metadata, embedding)
-
- def _issue_job_for_chunks(self, chunks: List[Chunk], batch_id: str) -> str:
- """Issues a batch embedding job for the given chunks. Returns the job ID."""
- logging.info("*" * 100)
- logging.info("Issuing job for batch %s with %d chunks.", batch_id, len(chunks))
-
- # Create a .jsonl file with the batch.
- request = OpenAIBatchEmbedder._chunks_to_request(chunks, batch_id, self.embedding_model, self.embedding_size)
- input_file = os.path.join(self.local_dir, f"batch_{batch_id}.jsonl")
- OpenAIBatchEmbedder._export_to_jsonl([request], input_file)
-
- # Upload the file and issue the embedding job.
- batch_input_file = self.client.files.create(file=open(input_file, "rb"), purpose="batch")
- batch_status = self._create_batch_job(batch_input_file.id)
- logging.info("Created job with ID %s", batch_status.id)
- return batch_status.id
-
- def _create_batch_job(self, input_file_id: str):
- """Creates a batch embedding job for OpenAI."""
- try:
- return self.client.batches.create(
- input_file_id=input_file_id,
- endpoint="/v1/embeddings",
- completion_window="24h", # This is the only allowed value for now.
- timeout=3 * 60, # 3 minutes
- metadata={},
- )
- except Exception as e:
- logging.error(f"Failed to create batch job with input_file_id={input_file_id}. Error: {e}")
- return None
-
- @staticmethod
- def _export_to_jsonl(list_of_dicts: List[Dict], output_file: str):
- """Exports a list of dictionaries to a .jsonl file."""
- directory = os.path.dirname(output_file)
- if not os.path.exists(directory):
- os.makedirs(directory)
- with open(output_file, "w") as f:
- for item in list_of_dicts:
- json.dump(item, f)
- f.write("\n")
-
- @staticmethod
- def _chunks_to_request(chunks: List[Chunk], batch_id: str, model: str, dimensions: Optional[int] = None) -> Dict:
- """Convert a list of chunks to a batch request."""
- body = {
- "model": model,
- "input": [chunk.content for chunk in chunks],
- }
-
- # These are the only two models that support a dynamic embedding size.
- if model in ["text-embedding-3-small", "text-embedding-3-large"] and dimensions is not None:
- body["dimensions"] = dimensions
-
- return {
- "custom_id": batch_id,
- "method": "POST",
- "url": "/v1/embeddings",
- "body": body,
- }
-
-
-class VoyageBatchEmbedder(BatchEmbedder):
- """Batch embedder that calls Voyage. See https://docs.voyageai.com/reference/embeddings-api."""
-
- def __init__(self, data_manager: DataManager, chunker: Chunker, embedding_model: str):
- self.data_manager = data_manager
- self.chunker = chunker
- self.embedding_model = embedding_model
- self.embedding_data = []
-
- def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None):
- """Issues batch embedding jobs for the entire dataset."""
- batch = []
- chunk_count = 0
-
- num_files = sum(1 for _ in self.data_manager.walk(get_content=False))
- pbar = tqdm(total=num_files, desc="Processing files", unit="file")
-
- for content, metadata in self.data_manager.walk():
- chunks = self.chunker.chunk(content, metadata)
- chunk_count += len(chunks)
- batch.extend(chunks)
- pbar.update(1)
-
- # Voyage rate-limits to 1M tokens per minute. Pause whenever the running (upper-bound)
- # token count crosses another multiple of 900k; an exact modulo check would almost
- # never fire, since the count grows in variable-sized steps.
- prev_token_count = (chunk_count - len(chunks)) * self.chunker.max_tokens
- token_count = chunk_count * self.chunker.max_tokens
- if token_count // 900_000 > prev_token_count // 900_000:
- logging.info("Pausing for 60 seconds to avoid rate limiting...")
- time.sleep(60)
-
- if len(batch) > chunks_per_batch:
- for i in range(0, len(batch), chunks_per_batch):
- sub_batch = batch[i : i + chunks_per_batch]
- logging.info("Embedding %d chunks...", len(sub_batch))
- result = self._make_batch_request(sub_batch)
- for chunk, datum in zip(sub_batch, result["data"]):
- self.embedding_data.append((chunk.metadata, datum["embedding"]))
- batch = []
-
- # Finally, commit the last batch.
- if batch:
- logging.info("Embedding %d chunks...", len(batch))
- result = self._make_batch_request(batch)
- for chunk, datum in zip(batch, result["data"]):
- self.embedding_data.append((chunk.metadata, datum["embedding"]))
- pbar.close()
- logging.info(f"Successfully embedded {chunk_count} chunks.")
-
- def embeddings_are_ready(self, *args, **kwargs) -> bool:
- """Checks whether the batch embedding jobs are done."""
- # The Voyage API is synchronous, so once embed_dataset() returns, the embeddings are ready.
- return True
-
- def download_embeddings(self, *args, **kwargs) -> Generator[Vector, None, None]:
- """Yields (chunk_metadata, embedding) pairs for each chunk in the dataset."""
- for chunk_metadata, embedding in self.embedding_data:
- yield (chunk_metadata, embedding)
-
- @retry(wait=wait_random_exponential(multiplier=1, max=60), stop=stop_after_attempt(6))
- def _make_batch_request(self, chunks: List[Chunk]) -> Dict:
- """Makes a batch request to the Voyage API with exponential backoff when we hit rate limits."""
- url = "https://api.voyageai.com/v1/embeddings"
- headers = {"Authorization": f"Bearer {os.environ['VOYAGE_API_KEY']}", "Content-Type": "application/json"}
- payload = {"input": [chunk.content for chunk in chunks], "model": self.embedding_model}
-
- response = requests.post(url, json=payload, headers=headers)
- if response.status_code != 200:
- raise ValueError(f"Failed to make batch request. Response: {response.text}")
-
- return response.json()
-
-
-class MarqoEmbedder(BatchEmbedder):
- """Embedder that uses the open-source Marqo vector search engine.
-
- Embeddings can be stored locally (in which case the `url` argument of the constructor should point to localhost) or in the cloud.
- """
-
- def __init__(self, data_manager: DataManager, chunker: Chunker, index_name: str, url: str, model="hf/e5-base-v2"):
- self.data_manager = data_manager
- self.chunker = chunker
- self.client = marqo.Client(url=url)
- self.index = self.client.index(index_name)
-
- all_index_names = [result["indexName"] for result in self.client.get_indexes()["results"]]
- if index_name not in all_index_names:
- self.client.create_index(index_name, model=model)
-
- def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None):
- """Issues batch embedding jobs for the entire dataset with progress tracking."""
- if chunks_per_batch > 64:
- raise ValueError("Marqo enforces a limit of 64 chunks per batch.")
-
- chunk_count = 0
- batch = []
- job_count = 0
-
- num_files = sum(1 for _ in self.data_manager.walk(get_content=False))
- pbar = tqdm(total=num_files, desc="Processing files", unit="file")
-
- for content, metadata in self.data_manager.walk():
- chunks = self.chunker.chunk(content, metadata)
- chunk_count += len(chunks)
- batch.extend(chunks)
- pbar.update(1)
- if len(batch) > chunks_per_batch:
- for i in range(0, len(batch), chunks_per_batch):
- sub_batch = batch[i : i + chunks_per_batch]
- logging.info("Indexing %d chunks...", len(sub_batch))
- self.index.add_documents(
- documents=[chunk.metadata for chunk in sub_batch],
- tensor_fields=[TEXT_FIELD],
- )
- job_count += 1
-
- if max_embedding_jobs and job_count >= max_embedding_jobs:
- logging.info("Reached the maximum number of embedding jobs. Stopping.")
- pbar.close()
- return
- batch = []
- if batch:
- self.index.add_documents(documents=[chunk.metadata for chunk in batch], tensor_fields=[TEXT_FIELD])
-
- pbar.close()
- logging.info(f"Successfully embedded {chunk_count} chunks.")
-
- def embeddings_are_ready(self) -> bool:
- """Checks whether the batch embedding jobs are done."""
- # Marqo indexes documents synchronously, so once embed_dataset() returns, the embeddings are ready.
- return True
-
- def download_embeddings(self) -> Generator[Vector, None, None]:
- """Yields (chunk_metadata, embedding) pairs for each chunk in the dataset."""
- # Marqo stores embeddings as they are created, so they're already in the vector store. No need to download them
- # as we would with e.g. OpenAI, Cohere, or some other cloud-based embedding service.
- return []
-
-
-class GeminiBatchEmbedder(BatchEmbedder):
- """Batch embedder that calls Gemini."""
-
- def __init__(self, data_manager: DataManager, chunker: Chunker, embedding_model: str):
- self.data_manager = data_manager
- self.chunker = chunker
- self.embedding_data = []
- self.embedding_model = embedding_model
- genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
-
- def _make_batch_request(self, chunks: List[Chunk]) -> Dict:
- return genai.embed_content(
- model=self.embedding_model, content=[chunk.content for chunk in chunks], task_type="retrieval_document"
- )
-
- def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None):
- """Issues batch embedding jobs for the entire dataset."""
- batch = []
- chunk_count = 0
-
- request_count = 0
- last_request_time = time.time()
-
- num_files = sum(1 for _ in self.data_manager.walk(get_content=False))
- pbar = tqdm(total=num_files, desc="Processing files", unit="file")
-
- for content, metadata in self.data_manager.walk():
- chunks = self.chunker.chunk(content, metadata)
- chunk_count += len(chunks)
- batch.extend(chunks)
- pbar.update(1)
-
- if len(batch) > chunks_per_batch:
- for i in range(0, len(batch), chunks_per_batch):
- sub_batch = batch[i : i + chunks_per_batch]
- logging.info("Embedding %d chunks...", len(sub_batch))
- result = self._make_batch_request(sub_batch)
- for chunk, embedding in zip(sub_batch, result["embedding"]):
- self.embedding_data.append((chunk.metadata, embedding))
- request_count += 1
-
- # Pause if we're approaching the 1,500 requests-per-minute limit (we trigger at 1,400 to leave a buffer)
- # Rate limits here: https://ai.google.dev/gemini-api/docs/models/gemini
- current_time = time.time()
- elapsed_time = current_time - last_request_time
- if elapsed_time < 60 and request_count >= 1400:
- logging.info("Reached rate limit, pausing for 60 seconds...")
- time.sleep(60)
- last_request_time = current_time
- request_count = 0
- # Reset the last request time and request count if more than 60 sec have passed
- elif elapsed_time > 60:
- last_request_time = current_time
- request_count = 0
-
- batch = []
-
- # Finally, commit the last batch.
- if batch:
- logging.info("Embedding %d chunks...", len(batch))
- result = self._make_batch_request(batch)
- for chunk, embedding in zip(batch, result["embedding"]):
- self.embedding_data.append((chunk.metadata, embedding))
- pbar.close()
- logging.info(f"Successfully embedded {chunk_count} chunks.")
-
- def embeddings_are_ready(self, *args, **kwargs) -> bool:
- """Checks whether the batch embedding jobs are done."""
- return True
-
- def download_embeddings(self, *args, **kwargs) -> Generator[Vector, None, None]:
- """Yields (chunk_metadata, embedding) pairs for each chunk in the dataset."""
- for chunk_metadata, embedding in self.embedding_data:
- yield chunk_metadata, embedding
-
-
-def build_batch_embedder_from_flags(data_manager: DataManager, chunker: Chunker, args) -> BatchEmbedder:
- if args.embedding_provider == "openai":
- return OpenAIBatchEmbedder(data_manager, chunker, args.local_dir, args.embedding_model, args.embedding_size)
- elif args.embedding_provider == "voyage":
- return VoyageBatchEmbedder(data_manager, chunker, args.embedding_model)
- elif args.embedding_provider == "marqo":
- return MarqoEmbedder(
- data_manager, chunker, index_name=args.index_namespace, url=args.marqo_url, model=args.embedding_model
- )
- elif args.embedding_provider == "gemini":
- return GeminiBatchEmbedder(data_manager, chunker, embedding_model=args.embedding_model)
- else:
- raise ValueError(f"Unrecognized embedder type {args.embedding_provider}")
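All four embedders deleted above share the same accumulate-and-flush loop: collect chunks while walking files, then submit them in sub-batches of `chunks_per_batch`. Here is a simplified, provider-agnostic sketch of that pattern; `embed_fn` is a hypothetical stand-in for the provider API call, and unlike the real code this flushes per chunk rather than per file:

```python
from typing import Callable, Iterable, List, Tuple


def embed_in_batches(
    chunks: Iterable[str],
    chunks_per_batch: int,
    embed_fn: Callable[[List[str]], List[List[float]]],
) -> List[Tuple[str, List[float]]]:
    """Accumulates chunks and flushes them to embed_fn in fixed-size sub-batches."""
    results: List[Tuple[str, List[float]]] = []
    batch: List[str] = []
    for chunk in chunks:
        batch.append(chunk)
        if len(batch) >= chunks_per_batch:
            results.extend(zip(batch, embed_fn(batch)))
            batch = []
    if batch:  # flush the final, possibly partial, batch
        results.extend(zip(batch, embed_fn(batch)))
    return results
```

A fake `embed_fn` (e.g. one that maps each chunk to its length) is enough to exercise the batching logic without hitting any API.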
diff --git a/sage/github.py b/sage/github.py
deleted file mode 100644
index 934ed58982b7dfda4a5dc35f2e5d4396c917d7d9..0000000000000000000000000000000000000000
--- a/sage/github.py
+++ /dev/null
@@ -1,242 +0,0 @@
-"""GitHub-specific implementations for DataManager and Chunker."""
-
-import logging
-from dataclasses import dataclass
-from typing import Any, Dict, Generator, List, Tuple
-
-import requests
-import tiktoken
-
-from sage.chunker import Chunk, Chunker
-from sage.constants import TEXT_FIELD
-from sage.data_manager import DataManager
-
-tokenizer = tiktoken.get_encoding("cl100k_base")
-
-
-@dataclass
-class GitHubIssueComment:
- """A comment on a GitHub issue."""
-
- url: str
- html_url: str
- body: str
-
- @property
- def pretty(self):
- return f"""## Comment: {self.body}"""
-
-
-@dataclass
-class GitHubIssue:
- """A GitHub issue."""
-
- url: str
- html_url: str
- title: str
- body: str
- comments: List[GitHubIssueComment]
-
- @property
- def pretty(self):
- # Do not include the comments.
- return f"# Issue: {self.title}\n{self.body}"
-
-
-class GitHubIssuesManager(DataManager):
- """Class to manage the GitHub issues of a particular repository."""
-
- def __init__(self, repo_id: str, access_token: str, index_comments: bool = False, max_issues: int = None):
- super().__init__(dataset_id=repo_id + "/issues")
- self.repo_id = repo_id
- self.index_comments = index_comments
- self.max_issues = max_issues
- self.access_token = access_token
- if not self.access_token:
- raise ValueError("Please set the GITHUB_TOKEN environment variable when indexing GitHub issues.")
- self.issues = []
-
- def download(self) -> bool:
- """Downloads all open issues from a GitHub repository (including the comments)."""
- per_page = min(self.max_issues or 100, 100) # 100 is maximum per page
- url = f"https://api.github.com/repos/{self.repo_id}/issues?per_page={per_page}"
- while url:
- logging.info(f"Fetching issues from {url}")
- response = self._get_page_of_issues(url)
- response.raise_for_status()
- for issue in response.json():
- if "pull_request" not in issue:
- self.issues.append(
- GitHubIssue(
- url=issue["url"],
- html_url=issue["html_url"],
- title=issue["title"],
- # When there's no body, issue["body"] is None.
- body=issue["body"] or "",
- comments=self._get_comments(issue["comments_url"]) if self.index_comments else [],
- )
- )
- if self.max_issues and len(self.issues) >= self.max_issues:
- break
- url = GitHubIssuesManager._get_next_link_from_header(response)
- return True
-
- def walk(self) -> Generator[Tuple[Any, Dict], None, None]:
- """Yields a tuple of (issue_content, issue_metadata) for each GitHub issue in the repository."""
- for issue in self.issues:
- yield issue, {} # empty metadata
-
- @staticmethod
- def _get_next_link_from_header(response):
- """
- Given a response from a paginated request, extracts the URL of the next page.
-
- Example:
- response.headers.get("link") = '<https://api.github.com/repositories/2503910/issues?per_page=10&page=2>; rel="next", <...>; rel="last"'
- _get_next_link_from_header(response) = 'https://api.github.com/repositories/2503910/issues?per_page=10&page=2'
- """
- link_header = response.headers.get("link")
- if link_header:
- links = link_header.split(", ")
- for link in links:
- url, rel = link.split("; ")
- url = url[1:-1] # The URL is enclosed in angle brackets
- rel = rel[5:-1] # e.g. rel="next" -> next
- if rel == "next":
- return url
- return None
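The helper above parses GitHub's paginated `Link` response header. The same logic can be sketched as a standalone function with no `requests` dependency (name and signature are ours, not the library's):

```python
from typing import Optional


def next_link_from_header(link_header: Optional[str]) -> Optional[str]:
    """Extracts the URL tagged rel="next" from a GitHub-style Link header value."""
    if not link_header:
        return None
    for link in link_header.split(", "):
        url, _, rel = link.partition("; ")
        if rel == 'rel="next"':
            return url.strip("<>")  # the URL is enclosed in angle brackets
    return None
```

Returning `None` when there is no `rel="next"` entry is what terminates the `while url:` download loop above.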
-
- def _get_page_of_issues(self, url):
- """Downloads a single page of issues. Note that GitHub uses pagination for long lists of objects."""
- return requests.get(
- url,
- headers={
- "Authorization": f"Bearer {self.access_token}",
- "X-GitHub-Api-Version": "2022-11-28",
- },
- )
-
- def _get_comments(self, comments_url) -> List[GitHubIssueComment]:
- """Downloads all the comments associated with an issue; returns an empty list if the request times out."""
- try:
- response = requests.get(
- comments_url,
- headers={
- "Authorization": f"Bearer {self.access_token}",
- "X-GitHub-Api-Version": "2022-11-28",
- },
- )
- except requests.exceptions.ConnectTimeout:
- logging.warning(f"Timeout fetching comments from {comments_url}")
- return []
- comments = []
- for comment in response.json():
- comments.append(
- GitHubIssueComment(
- url=comment["url"],
- html_url=comment["html_url"],
- body=comment["body"],
- )
- )
- return comments
-
-
-@dataclass
-class IssueChunk(Chunk):
- """A chunk from a GitHub issue with a contiguous (sub)set of comments.
-
- Note that, in comparison to FileChunk, its properties are not cached. We want to allow fields to be changed in place
- and have e.g. the token count be recomputed. Compared to files, GitHub issues are typically smaller, so the overhead
- is less problematic.
- """
-
- issue: GitHubIssue
- start_comment: int
- end_comment: int # exclusive
-
- @property
- def content(self) -> str:
- """The title of the issue, followed by the comments in the chunk."""
- if self.start_comment == 0:
- # This is the first subsequence of comments. We'll include the entire body of the issue.
- issue_str = self.issue.pretty
- else:
- # This is a middle subsequence of comments. We'll only include the title of the issue.
- issue_str = f"# Issue: {self.issue.title}"
- # Now add the comments themselves.
- comments = self.issue.comments[self.start_comment : self.end_comment]
- comments_str = "\n\n".join([comment.pretty for comment in comments])
- return issue_str + "\n\n" + comments_str
-
- @property
- def metadata(self):
- """Converts the chunk to a dictionary that can be passed to a vector store."""
- return {
- "id": f"{self.issue.html_url}_{self.start_comment}_{self.end_comment}",
- "url": self.issue.html_url,
- "start_comment": self.start_comment,
- "end_comment": self.end_comment,
- # Note to developer: When choosing a large chunk size, you might exceed the vector store's metadata
- # size limit. In that case, you can simply store the start/end comment indices above, and fetch the
- # content of the issue on demand from the URL.
- TEXT_FIELD: self.content,
- }
-
- @property
- def num_tokens(self):
- """Number of tokens in this chunk."""
- return len(tokenizer.encode(self.content, disallowed_special=()))
-
-
-class GitHubIssuesChunker(Chunker):
- """Chunks a GitHub issue into smaller pieces of contiguous (sub)sets of comments."""
-
- def __init__(self, max_tokens: int):
- self.max_tokens = max_tokens
-
- def chunk(self, content: Any, metadata: Dict) -> List[Chunk]:
- """Chunks a GitHub issue into subsequences of comments."""
- del metadata # The metadata of the input issue is unused.
-
- issue = content # Rename for clarity.
- if not isinstance(issue, GitHubIssue):
- raise ValueError(f"Expected a GitHubIssue, got {type(issue)}.")
-
- chunks = []
-
- # First, create a chunk for the body of the issue. If it's too long, then truncate it.
- if len(tokenizer.encode(issue.pretty, disallowed_special=())) > self.max_tokens:
- title_len = len(tokenizer.encode(issue.title, disallowed_special=()))
- target_body_len = self.max_tokens - title_len - 20 # 20 for buffer
- trimmed_body = tokenizer.decode(tokenizer.encode(issue.body, disallowed_special=())[:target_body_len])
- trimmed_issue = GitHubIssue(
- url=issue.url,
- html_url=issue.html_url,
- title=issue.title,
- body=trimmed_body,
- comments=issue.comments,
- )
- issue_body_chunk = IssueChunk(trimmed_issue, 0, 0)
- else:
- issue_body_chunk = IssueChunk(issue, 0, 0)
-
- chunks.append(issue_body_chunk)
-
- for comment_idx, comment in enumerate(issue.comments):
- # This is just approximate, because when we actually add a comment to the chunk there might be some extra
- # tokens, like a "Comment:" prefix.
- approx_comment_size = len(tokenizer.encode(comment.body, disallowed_special=())) + 20 # 20 for buffer
-
- if chunks[-1].num_tokens + approx_comment_size > self.max_tokens:
- # Create a new chunk starting from this comment.
- chunks.append(
- IssueChunk(
- issue=issue,
- start_comment=comment_idx,
- end_comment=comment_idx + 1,
- )
- )
- else:
- # Add the comment to the existing chunk.
- chunks[-1].end_comment = comment_idx + 1
- return chunks
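The chunker above greedily grows the current chunk one comment at a time and starts a new chunk whenever the next comment would exceed `max_tokens`. Stripped of the GitHub specifics, the packing logic looks like this (token counts are passed in directly; the names are ours):

```python
from typing import List


def pack_greedy(sizes: List[int], budget: int) -> List[List[int]]:
    """Greedily groups item sizes into bins, opening a new bin whenever the
    current bin would overflow the budget. Mirrors how IssueChunk grows via
    end_comment until the token limit is hit."""
    bins: List[List[int]] = []
    current: List[int] = []
    total = 0
    for size in sizes:
        if current and total + size > budget:
            bins.append(current)
            current, total = [], 0
        current.append(size)
        total += size
    if current:
        bins.append(current)
    return bins
```

Note that an item larger than the budget still gets a bin of its own; in the real chunker, oversized bodies are truncated before packing.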
diff --git a/sage/index.py b/sage/index.py
deleted file mode 100644
index 3483f904a087db5b0d02a1efa9886f29333574d8..0000000000000000000000000000000000000000
--- a/sage/index.py
+++ /dev/null
@@ -1,116 +0,0 @@
-"""Runs a batch job to compute embeddings for an entire repo and stores them into a vector store."""
-
-import logging
-import os
-import time
-
-import configargparse
-
-import sage.config as sage_config
-from sage.chunker import UniversalFileChunker
-from sage.data_manager import GitHubRepoManager
-from sage.embedder import build_batch_embedder_from_flags
-from sage.github import GitHubIssuesChunker, GitHubIssuesManager
-from sage.vector_store import VectorStoreProvider, build_vector_store_from_args
-
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger()
-logger.setLevel(logging.INFO)
-
-
-def main():
- parser = configargparse.ArgParser(
- description="Batch-embeds a GitHub repository and its issues.", ignore_unknown_config_file_keys=True
- )
- sage_config.add_config_args(parser)
-
- arg_validators = [
- sage_config.add_repo_args(parser),
- sage_config.add_embedding_args(parser),
- sage_config.add_vector_store_args(parser),
- sage_config.add_indexing_args(parser),
- ]
-
- args = parser.parse_args()
-
- for validator in arg_validators:
- validator(args)
-
- if args.llm_retriever:
- logging.warning("The LLM retriever does not require indexing, so this script is a no-op.")
- return
-
- # Additionally validate embedder and vector store compatibility.
- vector_store_providers = [member.value for member in VectorStoreProvider]
- if args.embedding_provider == "openai" and args.vector_store_provider not in vector_store_providers:
- parser.error(
- f"When using OpenAI embedder, the vector store type must be from the list {vector_store_providers}."
- )
- if args.embedding_provider == "marqo" and args.vector_store_provider != "marqo":
- parser.error("When using the marqo embedder, the vector store type must also be marqo.")
-
- ######################
- # Step 1: Embeddings #
- ######################
-
- # Index the repository.
- repo_embedder = None
- if args.index_repo:
- logging.info("Cloning the repository...")
- repo_manager = GitHubRepoManager.from_args(args)
- logging.info("Embedding the repo...")
- chunker = UniversalFileChunker(max_tokens=args.tokens_per_chunk)
- repo_embedder = build_batch_embedder_from_flags(repo_manager, chunker, args)
- repo_jobs_file = repo_embedder.embed_dataset(args.chunks_per_batch, args.max_embedding_jobs)
-
- # Index the GitHub issues.
- issues_embedder = None
- if args.index_issues:
- logging.info("Issuing embedding jobs for GitHub issues...")
- issues_manager = GitHubIssuesManager(
- args.repo_id, access_token=os.getenv("GITHUB_TOKEN"), index_comments=args.index_issue_comments
- )
- issues_manager.download()
- logging.info("Embedding GitHub issues...")
- chunker = GitHubIssuesChunker(max_tokens=args.tokens_per_chunk)
- issues_embedder = build_batch_embedder_from_flags(issues_manager, chunker, args)
- issues_jobs_file = issues_embedder.embed_dataset(args.chunks_per_batch, args.max_embedding_jobs)
-
- ########################
- # Step 2: Vector Store #
- ########################
-
- if args.vector_store_provider == "marqo":
- # Marqo computes embeddings and stores them in the vector store at once, so we're done.
- logging.info("Done!")
- return
-
- if repo_embedder is not None:
- logging.info("Waiting for repo embeddings to be ready...")
- while not repo_embedder.embeddings_are_ready(repo_jobs_file):
- logging.info("Sleeping for 30 seconds...")
- time.sleep(30)
-
- logging.info("Moving embeddings to the repo vector store...")
- repo_vector_store = build_vector_store_from_args(args, repo_manager)
- repo_vector_store.ensure_exists()
- repo_vector_store.upsert(repo_embedder.download_embeddings(repo_jobs_file), namespace=args.index_namespace)
-
- if issues_embedder is not None:
- logging.info("Waiting for issue embeddings to be ready...")
- while not issues_embedder.embeddings_are_ready(issues_jobs_file):
- logging.info("Sleeping for 30 seconds...")
- time.sleep(30)
-
- logging.info("Moving embeddings to the issues vector store...")
- issues_vector_store = build_vector_store_from_args(args, issues_manager)
- issues_vector_store.ensure_exists()
- issues_vector_store.upsert(
- issues_embedder.download_embeddings(issues_jobs_file), namespace=args.index_namespace
- )
-
- logging.info("Done!")
-
-
-if __name__ == "__main__":
- main()
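Step 2 of the indexing script polls `embeddings_are_ready` every 30 seconds before moving embeddings into the vector store. That wait loop can be sketched compactly, with the sleep injected so it can be tested without real delays (`is_ready`, `sleep_fn`, and `max_polls` are hypothetical parameters of our own):

```python
import time
from typing import Callable


def wait_until_ready(
    is_ready: Callable[[], bool],
    poll_seconds: float = 30.0,
    max_polls: int = 2880,  # ~24h at 30s, matching OpenAI's batch completion window
    sleep_fn: Callable[[float], None] = time.sleep,
) -> bool:
    """Polls is_ready() until it returns True, sleeping between attempts.
    Gives up after max_polls attempts, with one final check."""
    for _ in range(max_polls):
        if is_ready():
            return True
        sleep_fn(poll_seconds)
    return is_ready()
```

Capping the number of polls is a small hedge the real script omits; without it, a permanently failed job would block forever.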
diff --git a/sage/llm.py b/sage/llm.py
deleted file mode 100644
index 8b334604584763a2a8c2b4666a072b9c7a44f392..0000000000000000000000000000000000000000
--- a/sage/llm.py
+++ /dev/null
@@ -1,21 +0,0 @@
-import os
-
-from langchain_anthropic import ChatAnthropic
-from langchain_ollama import ChatOllama
-from langchain_openai import ChatOpenAI
-
-
-def build_llm_via_langchain(provider: str, model: str):
- """Builds a language model via LangChain."""
- if provider == "openai":
- if "OPENAI_API_KEY" not in os.environ:
- raise ValueError("Please set the OPENAI_API_KEY environment variable.")
- return ChatOpenAI(model=model or "gpt-4")
- elif provider == "anthropic":
- if "ANTHROPIC_API_KEY" not in os.environ:
- raise ValueError("Please set the ANTHROPIC_API_KEY environment variable.")
- return ChatAnthropic(model=model or "claude-3-opus-20240229")
- elif provider == "ollama":
- return ChatOllama(model=model or "llama3.1")
- else:
- raise ValueError(f"Unrecognized LLM provider {provider}. Contributions are welcome!")
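`build_llm_via_langchain` dispatches with an if/elif chain; the same logic can be expressed as a lookup table of (default model, required API-key env var) per provider. A sketch of that alternative, reusing the defaults from the code above (the actual LangChain constructor call is omitted):

```python
import os
from typing import Optional, Tuple

# provider -> (default model, required API-key env var; None means no key needed)
PROVIDER_DEFAULTS = {
    "openai": ("gpt-4", "OPENAI_API_KEY"),
    "anthropic": ("claude-3-opus-20240229", "ANTHROPIC_API_KEY"),
    "ollama": ("llama3.1", None),  # local inference, no API key
}


def resolve_llm(provider: str, model: Optional[str] = None) -> Tuple[str, str]:
    """Validates the provider and its API key; returns (provider, model)."""
    if provider not in PROVIDER_DEFAULTS:
        raise ValueError(f"Unrecognized LLM provider {provider}. Contributions are welcome!")
    default_model, env_var = PROVIDER_DEFAULTS[provider]
    if env_var and env_var not in os.environ:
        raise ValueError(f"Please set the {env_var} environment variable.")
    return provider, model or default_model
```

The table form makes adding a provider a one-line change, which is also the pattern `sage/reranker.py` below uses for its per-provider defaults.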
diff --git a/sage/reranker.py b/sage/reranker.py
deleted file mode 100644
index 0aa2c37aabcad8abacaad837833cb5a24228322f..0000000000000000000000000000000000000000
--- a/sage/reranker.py
+++ /dev/null
@@ -1,64 +0,0 @@
-import os
-from enum import Enum
-from typing import Optional
-
-from langchain.retrievers.document_compressors import CrossEncoderReranker
-from langchain_cohere import CohereRerank
-from langchain_community.cross_encoders import HuggingFaceCrossEncoder
-from langchain_community.document_compressors import JinaRerank
-from langchain_core.documents import BaseDocumentCompressor
-from langchain_nvidia_ai_endpoints import NVIDIARerank
-from langchain_voyageai import VoyageAIRerank
-
-
-class RerankerProvider(Enum):
- NONE = "none"
- HUGGINGFACE = "huggingface"
- COHERE = "cohere"
- NVIDIA = "nvidia"
- JINA = "jina"
- VOYAGE = "voyage"
-
-
-def build_reranker(provider: str, model: Optional[str] = None, top_k: int = 5) -> Optional[BaseDocumentCompressor]:
- if provider == RerankerProvider.NONE.value:
- return None
-
- api_key_env_vars = {
- RerankerProvider.COHERE.value: "COHERE_API_KEY",
- RerankerProvider.NVIDIA.value: "NVIDIA_API_KEY",
- RerankerProvider.JINA.value: "JINA_API_KEY",
- RerankerProvider.VOYAGE.value: "VOYAGE_API_KEY",
- }
-
- provider_defaults = {
- RerankerProvider.HUGGINGFACE.value: "cross-encoder/ms-marco-MiniLM-L-6-v2",
- RerankerProvider.COHERE.value: "rerank-english-v3.0",
- RerankerProvider.NVIDIA.value: "nvidia/nv-rerankqa-mistral-4b-v3",
- RerankerProvider.VOYAGE.value: "rerank-1",
- }
-
- model = model or provider_defaults.get(provider)
-
- if provider == RerankerProvider.HUGGINGFACE.value:
- encoder_model = HuggingFaceCrossEncoder(model_name=model)
- return CrossEncoderReranker(model=encoder_model, top_n=top_k)
-
- if provider in api_key_env_vars:
- api_key = os.getenv(api_key_env_vars[provider])
- if not api_key:
- raise ValueError(f"Please set the {api_key_env_vars[provider]} environment variable")
-
- if provider == RerankerProvider.COHERE.value:
- return CohereRerank(model=model, cohere_api_key=api_key, top_n=top_k)
-
- if provider == RerankerProvider.NVIDIA.value:
- return NVIDIARerank(model=model, api_key=api_key, top_n=top_k, truncate="END")
-
- if provider == RerankerProvider.JINA.value:
- return JinaRerank(top_n=top_k)
-
- if provider == RerankerProvider.VOYAGE.value:
- return VoyageAIRerank(model=model, api_key=api_key, top_k=top_k)
-
- raise ValueError(f"Invalid reranker provider: {provider}")
diff --git a/sage/retriever.py b/sage/retriever.py
deleted file mode 100644
index e718134b1733f33143a5b0dce0bcef525520e0cb..0000000000000000000000000000000000000000
--- a/sage/retriever.py
+++ /dev/null
@@ -1,352 +0,0 @@
-import logging
-import os
-from typing import Dict, List, Optional
-
-import anthropic
-import Levenshtein
-from anytree import Node, RenderTree
-from langchain.callbacks.manager import CallbackManagerForRetrieverRun
-from langchain.retrievers import ContextualCompressionRetriever
-from langchain.retrievers.multi_query import MultiQueryRetriever
-from langchain.schema import BaseRetriever, Document
-from langchain_google_genai import GoogleGenerativeAIEmbeddings
-from langchain_openai import OpenAIEmbeddings
-from langchain_voyageai import VoyageAIEmbeddings
-from pydantic import Field
-
-from sage.code_symbols import get_code_symbols
-from sage.data_manager import DataManager, GitHubRepoManager
-from sage.llm import build_llm_via_langchain
-from sage.reranker import build_reranker
-from sage.vector_store import build_vector_store_from_args
-
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger()
-logger.setLevel(logging.INFO)
-
-CLAUDE_MODEL = "claude-3-5-sonnet-20241022"
-CLAUDE_MODEL_CONTEXT_SIZE = 200_000
-
-
-class LLMRetriever(BaseRetriever):
- """Custom Langchain retriever based on an LLM.
-
- Builds a representation of the folder structure of the repo, feeds it to an LLM, and asks the LLM for the most
- relevant files for a particular user query, expecting it to make decisions based solely on file names.
-
- Only works with Claude/Anthropic, because it's very slow (e.g. 15s for a mid-sized codebase) and we need prompt
- caching to make it usable.
- """
-
- repo_manager: GitHubRepoManager = Field(...)
- top_k: int = Field(...)
-
- cached_repo_metadata: List[Dict] = Field(...)
- cached_repo_files: List[str] = Field(...)
- cached_repo_hierarchy: str = Field(...)
-
- def __init__(self, repo_manager: GitHubRepoManager, top_k: int):
- super().__init__()
- self.repo_manager = repo_manager
- self.top_k = top_k
-
- # We cached these fields manually because:
- # 1. Pydantic doesn't work with functools's @cached_property.
- # 2. We can't use Pydantic's @computed_field because these fields depend on each other.
- # 3. We can't use functools's @lru_cache because LLMRetriever needs to be hashable.
- self.cached_repo_metadata = None
- self.cached_repo_files = None
- self.cached_repo_hierarchy = None
-
- if not os.environ.get("ANTHROPIC_API_KEY"):
- raise ValueError("Please set the ANTHROPIC_API_KEY environment variable for the LLMRetriever.")
-
- @property
- def repo_metadata(self):
- if not self.cached_repo_metadata:
- self.cached_repo_metadata = [metadata for metadata in self.repo_manager.walk(get_content=False)]
-
- # Extracting code symbols takes quite a while, since we need to read each file from disk.
- # As a compromise, we do it for small codebases only.
- small_codebase = len(self.repo_files) <= 200
- if small_codebase:
- for metadata in self.cached_repo_metadata:
- file_path = metadata["file_path"]
- content = self.repo_manager.read_file(file_path)
- metadata["code_symbols"] = get_code_symbols(file_path, content)
-
- return self.cached_repo_metadata
-
- @property
- def repo_files(self):
- if not self.cached_repo_files:
- self.cached_repo_files = set(metadata["file_path"] for metadata in self.repo_metadata)
- return self.cached_repo_files
-
- @property
- def repo_hierarchy(self):
- """Produces a string that describes the structure of the repository. Depending on how big the codebase is, it
- might include class and method names."""
- if self.cached_repo_hierarchy is None:
- render = LLMRetriever._render_file_hierarchy(self.repo_metadata, include_classes=True, include_methods=True)
- max_tokens = CLAUDE_MODEL_CONTEXT_SIZE - 50_000 # 50,000 tokens for other parts of the prompt.
- client = anthropic.Anthropic()
-
- def count_tokens(x):
- count = client.beta.messages.count_tokens(model=CLAUDE_MODEL, messages=[{"role": "user", "content": x}])
- return count.input_tokens
-
- if count_tokens(render) > max_tokens:
- logging.info("File hierarchy is too large; excluding methods.")
- render = LLMRetriever._render_file_hierarchy(
- self.repo_metadata, include_classes=True, include_methods=False
- )
- if count_tokens(render) > max_tokens:
- logging.info("File hierarchy is still too large; excluding classes.")
- render = LLMRetriever._render_file_hierarchy(
- self.repo_metadata, include_classes=False, include_methods=False
- )
- if count_tokens(render) > max_tokens:
- logging.info("File hierarchy is still too large; truncating.")
-                    # The Anthropic SDK does not expose a public tokenizer, so
-                    # truncate by characters, assuming roughly 4 characters per token.
-                    render = render[: max_tokens * 4]
- self.cached_repo_hierarchy = render
- return self.cached_repo_hierarchy
-
- def _get_relevant_documents(self, query: str, *, run_manager: CallbackManagerForRetrieverRun) -> List[Document]:
- """Retrieve relevant documents for a given query."""
- filenames = self._ask_llm_to_retrieve(user_query=query, top_k=self.top_k)
- documents = []
- for filename in filenames:
- document = Document(
- page_content=self.repo_manager.read_file(filename),
- metadata={"file_path": filename, "url": self.repo_manager.url_for_file(filename)},
- )
- documents.append(document)
- return documents
-
- def _ask_llm_to_retrieve(self, user_query: str, top_k: int) -> List[str]:
- """Feeds the file hierarchy and user query to the LLM and asks which files might be relevant."""
- repo_hierarchy = str(self.repo_hierarchy)
- sys_prompt = f"""
-You are a retriever system. You will be given a user query and a list of files in a GitHub repository, together with the class names in each file.
-
-For instance:
-folder1
- folder2
- folder3
- file123.py
- ClassName1
- ClassName2
- ClassName3
-means that there is a file with path folder1/folder2/folder3/file123.py, which contains classes ClassName1, ClassName2, and ClassName3.
-
-Your task is to determine the top {top_k} files that are most relevant to the user query.
-DO NOT RESPOND TO THE USER QUERY DIRECTLY. Instead, respond with full paths to relevant files that could contain the answer to the query. Say absolutely nothing else other than the file paths.
-
-Here is the file hierarchy of the GitHub repository, together with the class names in each file:
-
-{repo_hierarchy}
-"""
-
- # We are deliberately repeating the "DO NOT RESPOND TO THE USER QUERY DIRECTLY" instruction here.
- augmented_user_query = f"""
-User query: {user_query}
-
-DO NOT RESPOND TO THE USER QUERY DIRECTLY. Instead, respond with full paths to relevant files that could contain the answer to the query. Say absolutely nothing else other than the file paths.
-"""
- response = LLMRetriever._call_via_anthropic_with_prompt_caching(sys_prompt, augmented_user_query)
-
- files_from_llm = response.content[0].text.strip().split("\n")
- validated_files = []
-
- for filename in files_from_llm:
- if filename not in self.repo_files:
- if "/" not in filename:
- # This is most likely some natural language excuse from the LLM; skip it.
- continue
- # Try a few heuristics to fix the filename.
- filename = LLMRetriever._fix_filename(filename, self.repo_manager.repo_id)
- if filename not in self.repo_files:
- # The heuristics failed; try to find the closest filename in the repo.
- filename = LLMRetriever._find_closest_filename(filename, self.repo_files)
- if filename in self.repo_files:
- validated_files.append(filename)
- return validated_files
-
- @staticmethod
- def _call_via_anthropic_with_prompt_caching(system_prompt: str, user_prompt: str) -> str:
- """Calls the Anthropic API with prompt caching for the system prompt.
-
- See https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching.
-
- We're circumventing LangChain for now, because the feature is < 1 week old at the time of writing and has no
- documentation: https://github.com/langchain-ai/langchain/pull/27087
- """
- system_message = {"type": "text", "text": system_prompt, "cache_control": {"type": "ephemeral"}}
- user_message = {"role": "user", "content": user_prompt}
-
- response = anthropic.Anthropic().beta.prompt_caching.messages.create(
- model=CLAUDE_MODEL,
- max_tokens=1024, # The maximum number of *output* tokens to generate.
- system=[system_message],
- messages=[user_message],
- )
- # Caching information will be under `cache_creation_input_tokens` and `cache_read_input_tokens`.
- # Note that, for prompts shorter than 1024 tokens, Anthropic will not do any caching.
- logging.info("Anthropic prompt caching info: %s", response.usage)
- return response
-
- @staticmethod
- def _render_file_hierarchy(
- repo_metadata: List[Dict], include_classes: bool = True, include_methods: bool = True
- ) -> str:
- """Given a list of files, produces a visualization of the file hierarchy. This hierarchy optionally includes
- class and method names, if available.
-
- For large codebases, including both classes and methods might exceed the token limit of the LLM. In that case,
- try setting `include_methods=False` first. If that's still too long, try also setting `include_classes=False`.
-
- As a point of reference, the Transformers library requires setting `include_methods=False` to fit within
- Claude's 200k context.
-
- Example:
- folder1
- folder11
- file111.md
- file112.py
- ClassName1
- method_name1
- method_name2
- method_name3
- folder12
- file121.py
- ClassName2
- ClassName3
- folder2
- file21.py
- """
- # The "nodepath" is the path from root to the node (e.g. huggingface/transformers/examples)
- nodepath_to_node = {}
-
- for metadata in repo_metadata:
- path = metadata["file_path"]
- paths = [path]
-
- if include_classes or include_methods:
- # Add the code symbols to the path. For instance, "folder/myfile.py/ClassName/method_name".
- for class_name, method_name in metadata.get("code_symbols", []):
- if include_classes and class_name:
- paths.append(path + "/" + class_name)
- # We exclude private methods to save tokens.
- if include_methods and method_name and not method_name.startswith("_"):
- paths.append(
- path + "/" + class_name + "/" + method_name if class_name else path + "/" + method_name
- )
-
- for path in paths:
- items = path.split("/")
- nodepath = ""
- parent_node = None
- for item in items:
- nodepath = f"{nodepath}/{item}"
- if nodepath in nodepath_to_node:
- node = nodepath_to_node[nodepath]
- else:
- node = Node(item, parent=parent_node)
- nodepath_to_node[nodepath] = node
- parent_node = node
-
- root_path = "/" + repo_metadata[0]["file_path"].split("/")[0]
- full_render = ""
- root_node = nodepath_to_node[root_path]
- for pre, fill, node in RenderTree(root_node):
- render = "%s%s\n" % (pre, node.name)
- # Replace special lines with empty strings to save on tokens.
- render = render.replace("└", " ").replace("├", " ").replace("│", " ").replace("─", " ")
- full_render += render
- return full_render
-
- @staticmethod
- def _fix_filename(filename: str, repo_id: str) -> str:
- """Attempts to "fix" a filename output by the LLM.
-
- Common issues with LLM-generated filenames:
- - The LLM prepends an extraneous "/".
- - The LLM omits the name of the org (e.g. "transformers/README.md" instead of "huggingface/transformers/README.md").
- - The LLM omits the name of the repo (e.g. "huggingface/README.md" instead of "huggingface/transformers/README.md").
- - The LLM omits the org/repo prefix (e.g. "README.md" instead of "huggingface/transformers/README.md").
- """
- if filename.startswith("/"):
- filename = filename[1:]
- org_name, repo_name = repo_id.split("/")
- items = filename.split("/")
- if filename.startswith(org_name) and not filename.startswith(repo_id):
- new_items = [org_name, repo_name] + items[1:]
- return "/".join(new_items)
-        if not filename.startswith(org_name) and filename.startswith(repo_name):
-            return f"{org_name}/{filename}"
-        if not filename.startswith(org_name) and not filename.startswith(repo_name):
-            return f"{org_name}/{repo_name}/{filename}"
- return filename
-
- @staticmethod
- def _find_closest_filename(filename: str, repo_filenames: List[str], max_edit_distance: int = 10) -> Optional[str]:
- """Returns the path in the repo with smallest edit distance from `filename`. Helpful when the `filename` was
- generated by an LLM and parts of it might have been hallucinated. Returns None if the closest path is more than
- `max_edit_distance` away. In case of a tie, returns an arbitrary closest path.
- """
- distances = [(path, Levenshtein.distance(filename, path)) for path in repo_filenames]
- distances.sort(key=lambda x: x[1])
- if distances[0][1] <= max_edit_distance:
- closest_path = distances[0][0]
- return closest_path
- return None
-
-
-class RerankerWithErrorHandling(BaseRetriever):
- """Wraps a `ContextualCompressionRetriever` to catch errors during inference.
-
- In practice, we see occasional `requests.exceptions.ReadTimeout` from the NVIDIA reranker, which crash the entire
- pipeline. This wrapper catches such exceptions by simply returning the documents in the original order.
- """
-
-    # BaseRetriever is a pydantic model, so the wrapped reranker must be a declared field.
-    reranker: ContextualCompressionRetriever = Field(...)
-
-    def __init__(self, reranker: ContextualCompressionRetriever):
-        super().__init__(reranker=reranker)
-
- def _get_relevant_documents(self, query: str, *, run_manager=None) -> List[Document]:
- try:
- return self.reranker._get_relevant_documents(query, run_manager=run_manager)
- except Exception as e:
- logging.error(f"Error in reranker; preserving original document order from retriever. {e}")
- return self.reranker.base_retriever._get_relevant_documents(query, run_manager=run_manager)
-
-
-def build_retriever_from_args(args, data_manager: Optional[DataManager] = None):
- """Builds a retriever (with optional reranking) from command-line arguments."""
- if args.llm_retriever:
- retriever = LLMRetriever(GitHubRepoManager.from_args(args), top_k=args.retriever_top_k)
- else:
- if args.embedding_provider == "openai":
- embeddings = OpenAIEmbeddings(model=args.embedding_model)
- elif args.embedding_provider == "voyage":
- embeddings = VoyageAIEmbeddings(model=args.embedding_model)
- elif args.embedding_provider == "gemini":
- embeddings = GoogleGenerativeAIEmbeddings(model=args.embedding_model)
- else:
- embeddings = None
-
- retriever = build_vector_store_from_args(args, data_manager).as_retriever(
- top_k=args.retriever_top_k, embeddings=embeddings, namespace=args.index_namespace
- )
-
- if args.multi_query_retriever:
- retriever = MultiQueryRetriever.from_llm(
- retriever=retriever, llm=build_llm_via_langchain(args.llm_provider, args.llm_model)
- )
-
- reranker = build_reranker(args.reranker_provider, args.reranker_model, args.reranker_top_k)
- if reranker:
- retriever = ContextualCompressionRetriever(base_compressor=reranker, base_retriever=retriever)
- return retriever
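The filename-repair heuristics in `_fix_filename` are a pure function of the LLM output and the repo id, so they can be exercised in isolation. A standalone sketch (the `fix_filename` name is ours; the logic follows the docstring's examples, with `repo_id` of the form `"org/repo"`):

```python
def fix_filename(filename: str, repo_id: str) -> str:
    """Repair common LLM mistakes in generated file paths."""
    if filename.startswith("/"):
        # Drop an extraneous leading slash.
        filename = filename[1:]
    org_name, repo_name = repo_id.split("/")
    items = filename.split("/")
    if filename.startswith(org_name) and not filename.startswith(repo_id):
        # Org present but repo omitted: splice the repo name in.
        return "/".join([org_name, repo_name] + items[1:])
    if not filename.startswith(org_name) and filename.startswith(repo_name):
        # Repo present but org omitted: prepend the org.
        return f"{org_name}/{filename}"
    if not filename.startswith(org_name) and not filename.startswith(repo_name):
        # Neither present: prepend the full org/repo prefix.
        return f"{org_name}/{repo_name}/{filename}"
    return filename
```

When these heuristics fail, the retriever falls back to the Levenshtein-based `_find_closest_filename`, which tolerates partially hallucinated paths at the cost of a full scan over the repo's file list.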
diff --git a/sage/sample-exclude.txt b/sage/sample-exclude.txt
deleted file mode 100644
index d1e9451e1d7d2b1b42966452a5e669d292a6e2a4..0000000000000000000000000000000000000000
--- a/sage/sample-exclude.txt
+++ /dev/null
@@ -1,94 +0,0 @@
-# This list tends to be overly aggressive. By default, we assume devs are most interested in code files, not configs.
-dir:_build
-dir:alembic
-dir:build
-dir:deprecated
-dir:docker
-dir:downgrades
-dir:fixtures
-dir:integration-tests
-dir:legacy
-dir:library-tests
-dir:logo
-dir:logs
-dir:migrations
-dir:node_modules
-dir:old-change-notes
-dir:test
-dir:testdata
-dir:tests
-dir:third_party
-dir:upgrades
-dir:vendor
-ext:.Packages
-ext:.avi
-ext:.bazel
-ext:.bin
-ext:.binpb
-ext:.bmp
-ext:.crt
-ext:.css
-ext:.csv
-ext:.dat
-ext:.db
-ext:.duckdb
-ext:.eot
-ext:.exe
-ext:.gguf
-ext:.gif
-ext:.glb
-ext:.gz
-ext:.icns
-ext:.ico
-ext:.inp
-ext:.isl
-ext:.jar
-ext:.jpeg
-ext:.jpg
-ext:.json
-ext:.key
-ext:.lock
-ext:.mo
-ext:.model
-ext:.mov
-ext:.mp3
-ext:.mp4
-ext:.otf
-ext:.out
-ext:.pb
-ext:.pdf
-ext:.pem
-ext:.pickle
-ext:.png
-ext:.pt
-ext:.ptl
-ext:.s
-ext:.so
-ext:.sql
-ext:.sqlite
-ext:.stl
-ext:.sum
-ext:.svg
-ext:.tar
-ext:.tgz
-ext:.th
-ext:.toml
-ext:.ts-fixture
-ext:.tsv
-ext:.ttf
-ext:.wav
-ext:.webp
-ext:.wmv
-ext:.woff
-ext:.woff2
-ext:.xml
-ext:.yaml
-ext:.yml
-ext:.zip
-file:CODE_OF_CONDUCT.md
-file:CONTRIBUTING.md
-file:Dockerfile
-file:__init__.py
-file:code-of-conduct.md
-file:conftest.py
-file:package-lock.json
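The exclusion file above uses three rule prefixes: `dir:` (skip any path containing that directory), `ext:` (skip by file extension), and `file:` (skip by exact filename). A small matcher illustrating one plausible interpretation of that syntax (the `is_excluded` helper is our assumption, not sage's actual parser):

```python
from typing import List


def is_excluded(path: str, rules: List[str]) -> bool:
    """Return True if `path` matches any dir:/ext:/file: exclusion rule."""
    parts = path.split("/")
    filename = parts[-1]
    for rule in rules:
        rule = rule.strip()
        if not rule or rule.startswith("#"):  # skip blanks and comments
            continue
        kind, _, value = rule.partition(":")
        if kind == "dir" and value in parts[:-1]:
            return True
        if kind == "ext" and filename.endswith(value):
            return True
        if kind == "file" and filename == value:
            return True
    return False
```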
diff --git a/sage/vector_store.py b/sage/vector_store.py
deleted file mode 100644
index 4dece26e53b388452ac624e991b9f2c65322398f..0000000000000000000000000000000000000000
--- a/sage/vector_store.py
+++ /dev/null
@@ -1,471 +0,0 @@
-"""Vector store abstraction and implementations."""
-
-import logging
-import os
-from abc import ABC, abstractmethod
-from enum import Enum
-from functools import cached_property
-from typing import Dict, Generator, List, Optional, Tuple
-from uuid import uuid4
-
-import chromadb
-import faiss
-import marqo
-import nltk
-from langchain.retrievers import EnsembleRetriever
-from langchain_chroma import Chroma as LangChainChroma
-from langchain_community.docstore.in_memory import InMemoryDocstore
-from langchain_community.retrievers import BM25Retriever
-from langchain_community.vectorstores import FAISS, Marqo
-from langchain_community.vectorstores import Pinecone as LangChainPinecone
-from langchain_core.documents import Document
-from langchain_core.embeddings import Embeddings
-from langchain_google_genai import GoogleGenerativeAIEmbeddings
-from langchain_milvus import Milvus
-from langchain_openai import OpenAIEmbeddings
-from langchain_qdrant import QdrantVectorStore as LangChainQdrant
-from langchain_voyageai import VoyageAIEmbeddings
-from nltk.data import find
-from pinecone import Pinecone, ServerlessSpec
-from pinecone_text.sparse import BM25Encoder
-from qdrant_client import QdrantClient
-from qdrant_client.http.models import Distance, VectorParams
-
-from sage.constants import TEXT_FIELD
-from sage.data_manager import DataManager
-
-Vector = Tuple[Dict, List[float]] # (metadata, embedding)
-
-
-class VectorStoreProvider(Enum):
- PINECONE = "pinecone"
- MARQO = "marqo"
- CHROMA = "chroma"
- FAISS = "faiss"
- MILVUS = "milvus"
- QDRANT = "qdrant"
-
-
-def is_punkt_downloaded():
- try:
- find("tokenizers/punkt_tab")
- return True
- except LookupError:
- return False
-
-
-class VectorStore(ABC):
- """Abstract class for a vector store."""
-
- @abstractmethod
- def ensure_exists(self):
- """Ensures that the vector store exists. Creates it if it doesn't."""
-
- @abstractmethod
- def upsert_batch(self, vectors: List[Vector], namespace: str):
- """Upserts a batch of vectors."""
-
- def upsert(self, vectors: Generator[Vector, None, None], namespace: str):
- """Upserts in batches of 100, since vector stores have a limit on upsert size."""
- batch = []
- for metadata, embedding in vectors:
- batch.append((metadata, embedding))
- if len(batch) == 100:
- self.upsert_batch(batch, namespace)
- batch = []
- if batch:
- self.upsert_batch(batch, namespace)
-
- @abstractmethod
- def as_retriever(self, top_k: int, embeddings: Embeddings, namespace: str):
- """Converts the vector store to a LangChain retriever object."""
-
-
-class PineconeVectorStore(VectorStore):
- """Vector store implementation using Pinecone."""
-
- def __init__(self, index_name: str, dimension: int, alpha: float, bm25_cache: Optional[str] = None):
- """
- Args:
- index_name: The name of the Pinecone index to use. If it doesn't exist already, we'll create it.
- dimension: The dimension of the vectors.
- alpha: The alpha parameter for hybrid search: alpha == 1.0 means pure dense search, alpha == 0.0 means pure
- BM25, and 0.0 < alpha < 1.0 means a hybrid of the two.
- bm25_cache: The path to the BM25 encoder file. If not specified, we'll use the default BM25 (fitted on the
- MS MARCO dataset).
- """
- self.index_name = index_name
- self.dimension = dimension
- self.client = Pinecone()
- self.alpha = alpha
- if alpha < 1.0:
- if bm25_cache and os.path.exists(bm25_cache):
- logging.info("Loading BM25 encoder from cache.")
-                # We need the nltk tokenizer for BM25 tokenization.
-                if not is_punkt_downloaded():
-                    nltk.download("punkt_tab")
- self.bm25_encoder = BM25Encoder()
- self.bm25_encoder.load(path=bm25_cache)
- else:
- logging.info("Using default BM25 encoder (fitted to MS MARCO).")
- self.bm25_encoder = BM25Encoder.default()
- else:
- self.bm25_encoder = None
-
- @cached_property
- def index(self):
- self.ensure_exists()
- index = self.client.Index(self.index_name)
-
-        # Hack around the fact that PineconeRetriever expects the content of the chunk to be in a "text" field,
-        # while PineconeHybridSearchRetriever expects it to be in a "context" field.
- original_query = index.query
-
- def patched_query(*args, **kwargs):
- result = original_query(*args, **kwargs)
- for res in result["matches"]:
- if TEXT_FIELD in res["metadata"]:
- res["metadata"]["context"] = res["metadata"][TEXT_FIELD]
- return result
-
- index.query = patched_query
- return index
-
- def ensure_exists(self):
- if self.index_name not in self.client.list_indexes().names():
- self.client.create_index(
- name=self.index_name,
- dimension=self.dimension,
- # See https://www.pinecone.io/learn/hybrid-search-intro/
- metric="dotproduct" if self.bm25_encoder else "cosine",
- spec=ServerlessSpec(cloud="aws", region="us-east-1"),
- )
-
- def upsert_batch(self, vectors: List[Vector], namespace: str):
- pinecone_vectors = []
- for i, (metadata, embedding) in enumerate(vectors):
- vector = {"id": metadata.get("id", str(i)), "values": embedding, "metadata": metadata}
- if self.bm25_encoder:
- vector["sparse_values"] = self.bm25_encoder.encode_documents(metadata[TEXT_FIELD])
- pinecone_vectors.append(vector)
-
- self.index.upsert(vectors=pinecone_vectors, namespace=namespace)
-
- def as_retriever(self, top_k: int, embeddings: Embeddings, namespace: str):
-        bm25_retriever = None
-        if self.bm25_encoder:
-            # BM25Retriever cannot issue sparse queries against Pinecone; hybrid
-            # search requires PineconeHybridSearchRetriever instead.
-            from langchain_community.retrievers import PineconeHybridSearchRetriever
-
-            bm25_retriever = PineconeHybridSearchRetriever(
-                embeddings=embeddings,
-                sparse_encoder=self.bm25_encoder,
-                index=self.index,
-                namespace=namespace,
-                top_k=top_k,
-            )
-
- dense_retriever = LangChainPinecone.from_existing_index(
- index_name=self.index_name, embedding=embeddings, namespace=namespace
- ).as_retriever(search_kwargs={"k": top_k})
-
- if bm25_retriever:
- return EnsembleRetriever(retrievers=[dense_retriever, bm25_retriever], weights=[self.alpha, 1 - self.alpha])
- else:
- return dense_retriever
-
-
-class ChromaVectorStore(VectorStore):
- """Vector store implementation using ChromaDB"""
-
- def __init__(self, index_name: str, alpha: float = None, bm25_cache: Optional[str] = None):
- """
- Args:
- index_name: The name of the Chroma collection/index to use. If it doesn't exist already, we'll create it.
- alpha: The alpha parameter for hybrid search: alpha == 1.0 means pure dense search, alpha == 0.0 means pure
- BM25, and 0.0 < alpha < 1.0 means a hybrid of the two.
- """
- self.index_name = index_name
- self.alpha = alpha
- self.client = chromadb.PersistentClient()
-
- @cached_property
- def index(self):
- index = self.client.get_or_create_collection(self.index_name)
- return index
-
- def ensure_exists(self):
- pass
-
- def upsert_batch(self, vectors: List[Vector], namespace: str):
- del namespace
-
- ids = []
- embeddings = []
- metadatas = []
- documents = []
-
- for i, (metadata, embedding) in enumerate(vectors):
- ids.append(metadata.get("id", str(i)))
- embeddings.append(embedding)
- metadatas.append(metadata)
- documents.append(metadata[TEXT_FIELD])
-
- self.index.upsert(ids=ids, embeddings=embeddings, metadatas=metadatas, documents=documents)
-
- def as_retriever(self, top_k: int, embeddings: Embeddings = None, namespace: str = None):
- vector_store = LangChainChroma(
- collection_name=self.index_name, embedding_function=embeddings, client=self.client
- )
-
- return vector_store.as_retriever(search_kwargs={"k": top_k})
-
-
-class FAISSVectorStore(VectorStore):
- """Vector store implementation using FAISS"""
-
- def __init__(self, index_name: str, dimension: int, embeddings: Embeddings = None):
- """
- Args:
- index_name: The name of the FAISS index to use. If it doesn't exist already, we'll create it.
- dimension: The dimension of the vectors.
- embeddings: The embedding function used to generate embeddings
- """
- self.index_name = index_name
- self.dimension = dimension
- self.embeddings = embeddings
-
- # check if the index exists
- if os.path.exists(self.index_name):
- # load the existing index
- self.vector_store = FAISS.load_local(
- folder_path=self.index_name, embeddings=self.embeddings, allow_dangerous_deserialization=True
- )
- # else create a new index
- else:
- self.vector_store = FAISS(
- embedding_function=self.embeddings,
- index=self.index,
- docstore=InMemoryDocstore(),
- index_to_docstore_id={},
- )
-
- @cached_property
- def index(self):
- index = faiss.IndexFlatL2(self.dimension)
- return index
-
- def ensure_exists(self):
- pass
-
- def upsert_batch(self, vectors: List[Vector], namespace: str):
- del namespace
-
- ids = []
- documents = []
-
- for i, (meta_data, embedding) in enumerate(vectors):
- ids.append(meta_data.get("id", str(i)))
- document = Document(page_content=meta_data[TEXT_FIELD], metadata=meta_data)
- documents.append(document)
-
- self.vector_store.add_documents(documents=documents, ids=ids)
-
-        # Save the index after every batch upsert.
-        self.vector_store.save_local(self.index_name)
-        logging.info("Saved FAISS index to %s", self.index_name)
-
- def as_retriever(self, top_k, embeddings, namespace):
- del embeddings
- del namespace
-
-        return self.vector_store.as_retriever(search_kwargs={"k": top_k})
-
-
-class MilvusVectorStore(VectorStore):
- """Vector store implementation using Milvus"""
-
- def __init__(self, uri: str, index_name: str, embeddings: Embeddings = None):
- """
- Args:
- index_name: The name of the Milvus collection to use. If it doesn't exist already, we'll create it.
- embeddings: The embedding function used to generate embeddings
- """
- self.uri = uri
- self.index_name = index_name
- self.embeddings = embeddings
-
- self.vector_store = Milvus(
- embedding_function=embeddings, connection_args={"uri": self.uri}, collection_name=self.index_name
- )
-
- def ensure_exists(self):
- pass
-
- def upsert_batch(self, vectors: List[Vector], namespace: str):
- del namespace
-
- ids = []
- documents = []
-
- for i, (meta_data, embedding) in enumerate(vectors):
- ids.append(meta_data.get("id", str(i)))
-            # "text" is a reserved field name in Milvus, so move its content to "content".
- page_content = meta_data[TEXT_FIELD]
- meta_data["content"] = meta_data[TEXT_FIELD]
- del meta_data[TEXT_FIELD]
-
- document = Document(page_content=page_content, metadata=meta_data)
- documents.append(document)
-
- self.vector_store.add_documents(documents=documents, ids=ids)
-
- def as_retriever(self, top_k, embeddings, namespace):
- del embeddings
- del namespace
-
-        return self.vector_store.as_retriever(search_kwargs={"k": top_k})
-
-
-class QdrantVectorStore(VectorStore):
- """Vector store implementation using Qdrant"""
-
- def __init__(self, index_name: str, dimension: int, embeddings: Embeddings = None):
- """
- Args:
- index_name: The name of the Qdrant collection to use. If it doesn't exist already, we'll create it.
- embeddings: The embedding function used to generate embeddings
- """
- self.index_name = index_name
- self.dimension = dimension
- self.embeddings = embeddings
- self.client = QdrantClient(path="qdrantdb")
- self.vector_store = self.index
-
- @cached_property
- def index(self):
- self.ensure_exists()
- vector_store = LangChainQdrant(client=self.client, collection_name=self.index_name, embedding=self.embeddings)
- return vector_store
-
- def ensure_exists(self):
- if not self.client.collection_exists(self.index_name):
- self.client.create_collection(
- collection_name=self.index_name,
- vectors_config=VectorParams(size=self.dimension, distance=Distance.COSINE),
- )
-
- def upsert_batch(self, vectors: List[Vector], namespace: str):
- del namespace
-
- ids = []
- documents = []
-
- for i, (meta_data, embedding) in enumerate(vectors):
- ids.append(str(uuid4()))
- document = Document(page_content=meta_data[TEXT_FIELD], metadata=meta_data)
- documents.append(document)
-
- self.vector_store.add_documents(documents=documents, ids=ids)
-
- def as_retriever(self, top_k, embeddings, namespace):
- del embeddings
- del namespace
-
-        return self.vector_store.as_retriever(search_kwargs={"k": top_k})
-
-
-class MarqoVectorStore(VectorStore):
- """Vector store implementation using Marqo."""
-
- def __init__(self, url: str, index_name: str):
- self.client = marqo.Client(url=url)
- self.index_name = index_name
-
- def ensure_exists(self):
- pass
-
- def upsert_batch(self, vectors: List[Vector], namespace: str):
- # Since Marqo is both an embedder and a vector store, the embedder is already doing the upsert.
- pass
-
- def as_retriever(self, top_k: int, embeddings: Embeddings = None, namespace: str = None):
- del embeddings # Unused; The Marqo vector store is also an embedder.
- del namespace # Unused; Unlike Pinecone, Marqo doesn't differentiate between index name and namespace.
-
- vectorstore = Marqo(client=self.client, index_name=self.index_name)
-
- # Monkey-patch the _construct_documents_from_results_without_score method to not expect a "metadata" field in
- # the result, and instead take the "filename" directly from the result.
- def patched_method(self, results):
- documents: List[Document] = []
- for result in results["hits"]:
- content = result.pop(TEXT_FIELD)
- documents.append(Document(page_content=content, metadata=result))
- return documents
-
- vectorstore._construct_documents_from_results_without_score = patched_method.__get__(
- vectorstore, vectorstore.__class__
- )
- return vectorstore.as_retriever(search_kwargs={"k": top_k})
-
-
-def build_vector_store_from_args(
- args: dict,
- data_manager: Optional[DataManager] = None,
-) -> VectorStore:
- """Builds a vector store from the given command-line arguments.
-
- When `data_manager` is specified and hybrid retrieval is requested, we'll use it to fit a BM25 encoder on the corpus
- of documents.
- """
- if args.embedding_provider == "openai":
- embeddings = OpenAIEmbeddings(model=args.embedding_model)
- elif args.embedding_provider == "voyage":
- embeddings = VoyageAIEmbeddings(model=args.embedding_model)
-    elif args.embedding_provider == "gemini":
-        embeddings = GoogleGenerativeAIEmbeddings(model=args.embedding_model)
-    else:
-        embeddings = None
-
- if args.vector_store_provider == "pinecone":
- bm25_cache = os.path.join(".bm25_cache", args.index_namespace, "bm25_encoder.json")
- if args.retrieval_alpha < 1.0 and not os.path.exists(bm25_cache) and data_manager:
- logging.info("Fitting BM25 encoder on the corpus...")
-            # We need the nltk tokenizer for BM25 tokenization.
-            if not is_punkt_downloaded():
-                nltk.download("punkt_tab")
- corpus = [content for content, _ in data_manager.walk()]
- bm25_encoder = BM25Encoder()
- bm25_encoder.fit(corpus)
- # Make sure the folder exists, before we dump the encoder.
- bm25_folder = os.path.dirname(bm25_cache)
- if not os.path.exists(bm25_folder):
- os.makedirs(bm25_folder)
- bm25_encoder.dump(bm25_cache)
-
- return PineconeVectorStore(
- index_name=args.index_name,
- dimension=args.embedding_size if "embedding_size" in args else None,
- alpha=args.retrieval_alpha,
- bm25_cache=bm25_cache,
- )
- elif args.vector_store_provider == "chroma":
- return ChromaVectorStore(
- index_name=args.index_name,
- )
- elif args.vector_store_provider == "faiss":
- return FAISSVectorStore(index_name=args.index_name, dimension=args.embedding_size, embeddings=embeddings)
- elif args.vector_store_provider == "milvus":
- return MilvusVectorStore(uri=args.milvus_uri, index_name=args.index_name, embeddings=embeddings)
- elif args.vector_store_provider == "qdrant":
- return QdrantVectorStore(index_name=args.index_name, dimension=args.embedding_size, embeddings=embeddings)
- elif args.vector_store_provider == "marqo":
- return MarqoVectorStore(url=args.marqo_url, index_name=args.index_namespace)
- else:
- raise ValueError(f"Unrecognized vector store type {args.vector_store_provider}")
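The base class's `upsert` method above reduces to a simple buffering loop: accumulate vectors and flush whenever the buffer reaches the batch size, then flush the remainder. A self-contained sketch of that pattern (with `upsert_batch` abstracted into a callable, since the store-specific writers vary):

```python
from typing import Callable, Dict, Iterable, List, Tuple

# (metadata, embedding), as in the module above.
Vector = Tuple[Dict, List[float]]


def upsert(
    vectors: Iterable[Vector],
    upsert_batch: Callable[[List[Vector]], None],
    batch_size: int = 100,
) -> None:
    """Buffer vectors and flush in fixed-size batches, mirroring VectorStore.upsert."""
    batch: List[Vector] = []
    for vector in vectors:
        batch.append(vector)
        if len(batch) == batch_size:
            upsert_batch(batch)
            batch = []  # start a fresh buffer; the flushed list is owned by the writer
    if batch:
        # Flush the final partial batch.
        upsert_batch(batch)
```

The batch size of 100 is a conservative choice that stays under the per-request payload limits of the hosted vector stores; the right value depends on embedding dimension and metadata size.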