import logging
import os
from typing import Any, Dict, List, Optional, Set

from anytree import Node, RenderTree
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import SystemMessage, HumanMessage
from pydantic import PrivateAttr
import Levenshtein

logger = logging.getLogger(__name__)


class LLMRetriever(BaseRetriever):
    """
    Retriever that uses an LLM to select relevant files from the project structure.
    Adapted from generic Sage implementation to work with LangChain models.
    """

    llm: BaseChatModel        # chat model used to select files from the tree
    repo_files: List[str]     # candidate file paths; assumed openable as given
    top_k: int = 5            # maximum number of files to return per query
    repo_structure: str = ""  # rendered tree of repo_files, populated in __init__

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Render the repository tree once up front so every query reuses it
        # instead of rebuilding the structure per retrieval.
        self.repo_structure = self._build_repo_structure(self.repo_files)

    def _build_repo_structure(self, files: List[str]) -> str:
        """Build a visual tree rendering of the repository file list.

        Args:
            files: File paths (relative or absolute) using ``/`` separators.

        Returns:
            A newline-terminated indented tree. Box-drawing characters are
            replaced with spaces for token efficiency in the LLM prompt.
        """
        root = Node("root")
        nodes = {"": root}
        for file_path in files:
            parts = file_path.strip("/").split("/")
            current_path = ""
            parent = root
            for part in parts:
                current_path = f"{current_path}/{part}" if current_path else part
                # Reuse existing directory nodes so shared prefixes merge.
                if current_path not in nodes:
                    nodes[current_path] = Node(part, parent=parent)
                parent = nodes[current_path]

        # Render the tree; swap box-drawing glyphs for spaces (fewer tokens,
        # indentation alone still conveys nesting depth).
        lines = []
        for pre, _, node in RenderTree(root):
            if node.name == "root":
                continue
            line = f"{pre}{node.name}"
            line = (line.replace("└", " ").replace("├", " ")
                        .replace("│", " ").replace("─", " "))
            lines.append(line + "\n")
        return "".join(lines)

    def _get_relevant_documents(self, query: str, *, run_manager: CallbackManagerForRetrieverRun) -> List[Document]:
        """Ask the LLM for relevant files, read them, and return Documents.

        Files the LLM selected but that do not exist on disk are still
        returned as empty Documents with an ``error`` metadata entry, so
        callers can see what was requested. Any unexpected failure in the
        retrieval pipeline is logged and yields an empty list.
        """
        try:
            logger.info("LLMRetriever: Asking LLM to select files...")
            filenames = self._ask_llm_to_retrieve(query)
            logger.info(f"LLMRetriever: Selected {len(filenames)} files: {filenames}")

            documents = []
            for filename in filenames:
                # repo_files may hold relative or absolute paths; we assume
                # they are openable from the current working directory.
                try:
                    if os.path.exists(filename):
                        with open(filename, "r", errors='ignore') as f:
                            content = f.read()
                        documents.append(Document(
                            page_content=content,
                            metadata={"file_path": filename, "source": "llm_retriever"}
                        ))
                    else:
                        documents.append(Document(
                            page_content="",
                            metadata={"file_path": filename, "source": "llm_retriever", "error": "File not found"}
                        ))
                except Exception as e:
                    # BUG FIX: the original message said "(unknown)" even
                    # though the failing filename is in scope.
                    logger.warning(f"Failed to read file {filename}: {e}")
            return documents
        except Exception as e:
            logger.error(f"LLMRetriever failed: {e}")
            return []

    def _ask_llm_to_retrieve(self, user_query: str) -> List[str]:
        """Feed the file hierarchy and user query to the LLM.

        Returns:
            Up to ``top_k`` repo file paths, deduplicated while preserving
            the LLM's ranking order.
        """
        system_prompt = f"""
You are a senior software engineer helping to navigate a codebase.
Your task is to identify the top {self.top_k} files in the repository that are most likely to contain the answer to the user's query.

Here is the file structure of the repository:
{self.repo_structure}

Rules:
1. Respond ONLY with a list of file paths, one per line.
2. Do not include any explanation or conversational text.
3. Select files that are relevant to: "{user_query}"
4. If the file paths in the structure are relative, return them as they appear in the structure.
"""
        messages = [
            SystemMessage(content=system_prompt),
            HumanMessage(content=f"User Query: {user_query}")
        ]

        response = self.llm.invoke(messages)
        text = response.content.strip()
        logger.info(f"DEBUG: Raw LLM Response: {text}")

        # Parse one path per line, tolerating "-"/"*" bullet prefixes.
        selected_files = []
        for line in text.split('\n'):
            cleaned = line.strip().strip("- ").strip("* ")
            if cleaned:
                # Validate against known files (fuzzy match if needed).
                match = self._find_best_match(cleaned)
                if match:
                    selected_files.append(match)

        # BUG FIX: the original used list(set(...)), which scrambles order
        # before truncation, so top_k kept an arbitrary subset instead of
        # the LLM's top-ranked files. dict.fromkeys dedupes in order.
        return list(dict.fromkeys(selected_files))[:self.top_k]

    def _find_best_match(self, filename: str) -> Optional[str]:
        """Find the closest matching filename from the repo.

        Matching strategy, in order: exact path, exact basename, then
        Levenshtein distance against the full path (the LLM saw full
        paths in the structure). Returns None when nothing is close.
        """
        if filename in self.repo_files:
            return filename

        # 1. Exact match on basename (LLM sometimes drops the directory).
        for f in self.repo_files:
            if os.path.basename(f) == filename:
                return f

        # 2. Fuzzy match on the full path.
        best_match = None
        min_dist = float('inf')
        for f in self.repo_files:
            dist = Levenshtein.distance(filename, f)
            if dist < min_dist:
                min_dist = dist
                best_match = f

        # NOTE(review): 20 is an arbitrary edit-distance cutoff inherited
        # from the original; tune against real LLM output if misses occur.
        if min_dist < 20:
            return best_match
        return None