Spaces:
Running
Running
File size: 7,238 Bytes
5b89d45 5ceb659 986715f 5ceb659 5b89d45 5ceb659 5b89d45 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 | import os
import networkx as nx
import logging
from typing import List, Optional, Any
from langchain_core.retrievers import BaseRetriever
from langchain_core.documents import Document
logger = logging.getLogger(__name__)
class GraphEnhancedRetriever(BaseRetriever):
"""Wraps a base retriever and augments results using an AST knowledge graph."""
base_retriever: BaseRetriever
graph: Optional[Any] = None
repo_dir: str
def __init__(self, base_retriever: BaseRetriever, repo_dir: str, **kwargs):
# Initialize Pydantic fields
super().__init__(base_retriever=base_retriever, repo_dir=repo_dir, **kwargs)
self.graph = self._load_graph()
def _load_graph(self):
graph_path = os.path.join(self.repo_dir, "ast_graph.graphml")
if os.path.exists(graph_path):
try:
logger.info(f"Loading AST Graph from {graph_path}")
return nx.read_graphml(graph_path)
except Exception as e:
logger.error(f"Failed to load AST graph: {e}")
else:
logger.warning(f"No AST graph found at {graph_path}")
return None
def _rerank_by_file_type(self, docs: List[Document]) -> List[Document]:
"""Rerank documents to prioritize source code over config/text files."""
# Priority weights: higher = more important
def get_priority(doc: Document) -> int:
file_path = doc.metadata.get("file_path", "").lower()
# Highest priority: Main entry points
main_files = ["main.py", "app.py", "index.js", "index.ts", "server.py", "api.py"]
if any(file_path.endswith(f) for f in main_files):
return 100
# High priority: Source code files
code_extensions = [".py", ".js", ".ts", ".jsx", ".tsx", ".java", ".go", ".rs", ".cpp", ".c"]
if any(file_path.endswith(ext) for ext in code_extensions):
return 80
# Medium priority: Config files (still useful)
config_extensions = [".json", ".yaml", ".yml", ".toml"]
if any(file_path.endswith(ext) for ext in config_extensions):
return 50
# Low priority: Text/doc files (often too generic)
# EXCEPTION: README files are critical for context
if "readme" in file_path.lower():
return 90
text_extensions = [".txt", ".md", ".rst"]
if any(file_path.endswith(ext) for ext in text_extensions):
return 30
# Default
return 40
# Sort by priority (descending), keeping relative order for same priority
ranked = sorted(docs, key=lambda d: get_priority(d), reverse=True)
logger.info(f"Reranked docs: top files are {[d.metadata.get('file_path', '?').split('/')[-1] for d in ranked[:3]]}")
return ranked
def _get_relevant_documents(self, query: str, *, run_manager=None) -> List[Document]:
# 1. Standard Retrieval
logger.info(f"GraphEnhancedRetriever: Querying base retriever with: '{query}'")
docs = self.base_retriever.invoke(query)
logger.info(f"GraphEnhancedRetriever: Base retriever returned {len(docs)} documents")
# 2. Rerank: Prioritize source code over config/text files
docs = self._rerank_by_file_type(docs)
if not self.graph:
logger.warning("No AST graph available for enhancement")
return docs
# 2. Graph Expansion
augmented_docs = list(docs)
seen_files = {d.metadata.get("file_path") for d in docs}
# We also want to see what files are already in the docs to avoid duplicating content
# But here we are looking for RELATED files that might not be in the vector search results.
for doc in docs:
file_path = doc.metadata.get("file_path")
if not file_path: continue
# Normalize path if needed (relative vs absolute)
# The graph was built with paths relative to extracting location or absolute?
# We need to ensure consistency.
# In ingestor we use: rel_path for source, but file_path for absolute.
# In ast_analysis we used file_path passed to add_file.
# We need to verify how we call add_file in app.py.
# Let's try to find the node in the graph
target_node = None
if file_path in self.graph:
target_node = file_path
else:
# Try checking if just filename match
# Or try absolute path match (depends on how we built the graph)
pass
if target_node and target_node in self.graph:
neighbors = list(self.graph.neighbors(target_node))
for neighbor in neighbors:
# Neighbor could be a file or a symbol (file::symbol)
if "::" in neighbor:
neighbor_file = neighbor.split("::")[0]
else:
neighbor_file = neighbor
# Skip if we've already seen this file
if neighbor_file in seen_files:
continue
# Check if file exists (handle both relative and absolute paths)
if os.path.exists(neighbor_file):
try:
# Limit expansion to small files to avoid context overflow
if os.path.getsize(neighbor_file) < 20000: # 20KB limit
with open(neighbor_file, "r", errors='ignore') as f:
content = f.read()
# Get relationship type from edge
edge_data = self.graph.get_edge_data(target_node, neighbor, {})
relation = edge_data.get("relation", "related") if edge_data else "related"
new_doc = Document(
page_content=f"--- Graph Context ({relation} from {os.path.basename(file_path)}) ---\n{content}",
metadata={
"file_path": neighbor_file,
"source": "ast_graph",
"relation": relation,
"related_to": file_path
}
)
augmented_docs.append(new_doc)
seen_files.add(neighbor_file)
logger.debug(f"Added graph-related file: {neighbor_file} (relation: {relation})")
except Exception as e:
logger.warning(f"Failed to add graph-related file {neighbor_file}: {e}")
return augmented_docs
|