import logging
import os
from typing import Any, Dict, List, Optional, Set
from anytree import Node, RenderTree
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import SystemMessage, HumanMessage
from pydantic import PrivateAttr
import Levenshtein
logger = logging.getLogger(__name__)
class LLMRetriever(BaseRetriever):
    """
    Retriever that asks an LLM to select relevant files from the project structure.

    The repository file list is rendered as a textual tree, shown to the LLM
    together with the user query, and the LLM's line-based answer is parsed
    back into file paths (fuzzy-matched against the known files) which are
    then read from disk into :class:`Document` objects.

    Adapted from the generic Sage implementation to work with LangChain models.
    """

    llm: BaseChatModel        # chat model used to select files
    repo_files: List[str]     # known file paths; assumed directly openable
    top_k: int = 5            # maximum number of files to return
    repo_structure: str = ""  # rendered tree, computed once in __init__

    # Maximum Levenshtein distance accepted when fuzzy-matching an
    # LLM-returned path against the known file list. Arbitrary threshold,
    # kept from the original implementation.
    _FUZZY_MATCH_THRESHOLD = 20

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # repo_structure is a declared (non-frozen) pydantic field, so a plain
        # assignment is valid here — no object.__setattr__ workaround needed.
        self.repo_structure = self._build_repo_structure(self.repo_files)

    def _build_repo_structure(self, files: List[str]) -> str:
        """Build a visual tree rendering of the repository file list.

        Returns one line per file/directory node, indented by depth, with a
        trailing newline (empty string for an empty file list).
        """
        root = Node("root")
        nodes = {"": root}
        for file_path in files:
            parts = file_path.strip("/").split("/")
            current_path = ""
            parent = root
            for part in parts:
                current_path = f"{current_path}/{part}" if current_path else part
                if current_path not in nodes:
                    nodes[current_path] = Node(part, parent=parent)
                parent = nodes[current_path]
        # Replace box-drawing characters with spaces in a single C-level pass
        # (str.translate) instead of four chained .replace() calls: the
        # indentation alone is enough structure for the LLM and plain spaces
        # are cheaper in tokens.
        box_to_space = str.maketrans({ch: " " for ch in "└├│─"})
        lines = []
        for pre, _, node in RenderTree(root):
            if node.name == "root":
                continue
            lines.append(f"{pre}{node.name}".translate(box_to_space))
        return "\n".join(lines) + "\n" if lines else ""

    def _get_relevant_documents(self, query: str, *, run_manager: CallbackManagerForRetrieverRun) -> List[Document]:
        """Retrieve documents for *query* by asking the LLM to pick files.

        Files that exist on disk are read into the document body; files the
        LLM named but that cannot be found yield an empty document tagged with
        an ``error`` metadata entry so callers still see what was selected.
        Returns [] if the LLM call itself fails.
        """
        try:
            logger.info("LLMRetriever: Asking LLM to select files...")
            filenames = self._ask_llm_to_retrieve(query)
            logger.info("LLMRetriever: Selected %d files: %s", len(filenames), filenames)
            documents = []
            for filename in filenames:
                # repo_files is assumed to hold paths that are directly
                # openable (absolute, or relative to the working directory);
                # no base_dir resolution is attempted here.
                try:
                    if os.path.exists(filename):
                        # Explicit encoding; errors='ignore' keeps best-effort
                        # behavior for files with stray non-UTF-8 bytes.
                        with open(filename, "r", encoding="utf-8", errors="ignore") as f:
                            content = f.read()
                        documents.append(Document(
                            page_content=content,
                            metadata={"file_path": filename, "source": "llm_retriever"},
                        ))
                    else:
                        documents.append(Document(
                            page_content="",
                            metadata={"file_path": filename, "source": "llm_retriever", "error": "File not found"},
                        ))
                except Exception as e:
                    # Fix: include the offending filename in the log message
                    # (previously logged a literal "(unknown)").
                    logger.warning(f"Failed to read file {filename!r}: {e}")
            return documents
        except Exception as e:
            logger.error(f"LLMRetriever failed: {e}")
            return []

    def _ask_llm_to_retrieve(self, user_query: str) -> List[str]:
        """Feed the file hierarchy and user query to the LLM; parse its answer.

        The LLM is instructed to answer with one file path per line; each line
        is cleaned of list bullets / backtick fences and fuzzy-matched against
        the known repo files. Returns at most ``top_k`` paths, deduplicated
        while preserving the LLM's ranking order.
        """
        system_prompt = f"""
You are a senior software engineer helping to navigate a codebase.
Your task is to identify the top {self.top_k} files in the repository that are most likely to contain the answer to the user's query.
Here is the file structure of the repository:
{self.repo_structure}
Rules:
1. Respond ONLY with a list of file paths, one per line.
2. Do not include any explanation or conversational text.
3. Select files that are relevant to: "{user_query}"
4. If the file paths in the structure are relative, return them as they appear in the structure.
"""
        messages = [
            SystemMessage(content=system_prompt),
            HumanMessage(content=f"User Query: {user_query}")
        ]
        response = self.llm.invoke(messages)
        text = response.content.strip()
        logger.info(f"DEBUG: Raw LLM Response: {text}")
        selected_files = []
        for line in text.split('\n'):
            # Strip whitespace, list bullets and markdown backticks; a pure
            # code-fence line (```) becomes empty and is skipped.
            cleaned = line.strip().strip("-* `").strip()
            if cleaned:
                # Validate against known files (fuzzy match if needed).
                match = self._find_best_match(cleaned)
                if match:
                    selected_files.append(match)
        # Fix: dict.fromkeys dedupes while preserving insertion order.
        # list(set(...)) randomized which files survived the top_k cut.
        return list(dict.fromkeys(selected_files))[:self.top_k]

    def _find_best_match(self, filename: str) -> Optional[str]:
        """Find the closest matching path in ``repo_files`` for *filename*.

        Match order: exact path, exact basename, then the repo file with the
        smallest Levenshtein distance to the full path (the LLM sees full
        paths in the structure). Returns None when even the best fuzzy match
        exceeds ``_FUZZY_MATCH_THRESHOLD``.
        """
        if filename in self.repo_files:
            return filename
        # 1. Exact match on basename (LLM sometimes drops the directory).
        for f in self.repo_files:
            if os.path.basename(f) == filename:
                return f
        # 2. Fuzzy match against the full path.
        best_match = None
        min_dist = float('inf')
        for f in self.repo_files:
            dist = Levenshtein.distance(filename, f)
            if dist < min_dist:
                min_dist = dist
                best_match = f
        if min_dist < self._FUZZY_MATCH_THRESHOLD:
            return best_match
        return None