Spaces:
Running
Running
File size: 5,286 Bytes
5b89d45 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 | import os
import networkx as nx
import logging
from typing import List, Optional, Any
from langchain_core.retrievers import BaseRetriever
from langchain_core.documents import Document
logger = logging.getLogger(__name__)
class GraphEnhancedRetriever(BaseRetriever):
    """Wraps a base retriever and augments results using an AST knowledge graph.

    The graph is read from ``<repo_dir>/ast_graph.graphml`` when the retriever
    is constructed. For every document returned by the base retriever, the
    files adjacent to that document's file in the graph are read from disk
    and appended to the results as extra ``Document`` objects. When no graph
    could be loaded, the base retriever's results are returned unchanged.
    """

    # Retriever whose results are expanded with graph neighbours.
    base_retriever: BaseRetriever
    # Graph returned by ``nx.read_graphml``, or None when it is unavailable.
    graph: Optional[Any] = None
    # Directory holding ``ast_graph.graphml``; also used as the base directory
    # when relative paths have to be resolved.
    repo_dir: str

    def __init__(self, base_retriever: BaseRetriever, repo_dir: str, **kwargs):
        """Initialise the model fields, then load the graph from ``repo_dir``."""
        # Initialize Pydantic fields
        super().__init__(base_retriever=base_retriever, repo_dir=repo_dir, **kwargs)
        self.graph = self._load_graph()

    def _load_graph(self):
        """Return the graph stored in ``repo_dir``, or None if missing/unreadable."""
        graph_path = os.path.join(self.repo_dir, "ast_graph.graphml")
        if not os.path.exists(graph_path):
            logger.warning(f"No AST graph found at {graph_path}")
            return None
        try:
            logger.info(f"Loading AST Graph from {graph_path}")
            return nx.read_graphml(graph_path)
        except Exception as e:
            # Deliberately broad: a corrupt graph must not prevent plain
            # retrieval from working, so degrade to "no graph".
            logger.error(f"Failed to load AST graph: {e}")
            return None

    def _path_key(self, path: str) -> str:
        """Return a normalised absolute form of *path* used for de-duplication.

        Relative paths are interpreted relative to ``repo_dir`` so that the
        relative and absolute spellings of one file compare equal.
        """
        if not os.path.isabs(path):
            path = os.path.join(self.repo_dir, path)
        return os.path.abspath(path)

    def _find_node(self, file_path: str) -> Optional[str]:
        """Return the graph node id that corresponds to *file_path*, if any.

        The exact string is tried first. Because the graph may have been
        built with relative or with absolute paths, the other spellings of
        the same file (relative to ``repo_dir`` / the working directory, or
        absolute under ``repo_dir``) are tried afterwards.
        """
        candidates = [file_path, os.path.normpath(file_path)]
        if os.path.isabs(file_path):
            for base in (self.repo_dir, os.curdir):
                try:
                    candidates.append(os.path.relpath(file_path, base))
                except ValueError:
                    # On Windows relpath fails across drives; that spelling
                    # simply does not exist, so skip it.
                    continue
        else:
            joined = os.path.join(self.repo_dir, file_path)
            candidates.extend((joined, os.path.abspath(joined)))

        for candidate in candidates:
            if candidate in self.graph:
                return candidate
        return None

    def _resolve_on_disk(self, path: str) -> Optional[str]:
        """Return a readable location for *path*, or None if it is not a file.

        The path is tried as given (absolute, or relative to the working
        directory) and then, for relative paths, relative to ``repo_dir``.
        """
        if os.path.isfile(path):
            return path
        if not os.path.isabs(path):
            candidate = os.path.join(self.repo_dir, path)
            if os.path.isfile(candidate):
                return candidate
        return None

    def _edge_relation(self, source: str, target: str) -> str:
        """Return the ``relation`` attribute of the edge, defaulting to "related"."""
        # ``default`` must be passed by keyword: on multigraphs the third
        # positional parameter of get_edge_data is the edge key, not the default.
        edge_data = self.graph.get_edge_data(source, target, default=None)
        if not edge_data:
            return "related"
        if self.graph.is_multigraph():
            # Multigraphs return {edge_key: attributes}; use the first parallel edge.
            edge_data = next(iter(edge_data.values()), None) or {}
        return edge_data.get("relation", "related")

    def _get_relevant_documents(self, query: str, *, run_manager=None) -> List[Document]:
        """Retrieve documents for *query* and append graph-adjacent files.

        Args:
            query: The search string forwarded to the base retriever.
            run_manager: Optional callback manager for this run; its child
                manager is forwarded so the base retriever's run is nested
                under this one.

        Returns:
            The base retriever's documents followed by one ``Document`` per
            newly discovered neighbouring file smaller than 20KB. Added
            documents carry ``source="ast_graph"`` plus ``relation`` and
            ``related_to`` metadata.
        """
        max_bytes = 20000  # 20KB limit, keeps expansion from overflowing the context

        # 1. Standard Retrieval
        logger.info(f"GraphEnhancedRetriever: Querying base retriever with: '{query}'")
        config = {"callbacks": run_manager.get_child()} if run_manager is not None else None
        docs = self.base_retriever.invoke(query, config=config)
        logger.info(f"GraphEnhancedRetriever: Base retriever returned {len(docs)} documents")

        if not self.graph:
            logger.warning("No AST graph available for enhancement")
            return docs

        # 2. Graph Expansion
        augmented_docs = list(docs)
        # Files already present in the results, keyed by normalised absolute
        # path so relative/absolute spellings of one file are not duplicated.
        seen_files = {
            self._path_key(path)
            for path in (d.metadata.get("file_path") for d in docs)
            if path
        }

        for doc in docs:
            file_path = doc.metadata.get("file_path")
            if not file_path:
                continue

            target_node = self._find_node(file_path)
            if target_node is None:
                continue

            for neighbor in self.graph.neighbors(target_node):
                # Neighbor could be a file or a symbol (file::symbol)
                neighbor_file = neighbor.partition("::")[0]
                if not neighbor_file:
                    continue

                neighbor_key = self._path_key(neighbor_file)
                if neighbor_key in seen_files:
                    continue

                disk_path = self._resolve_on_disk(neighbor_file)
                if disk_path is None:
                    # Not a file on disk (e.g. an external module or a directory).
                    continue

                try:
                    if os.path.getsize(disk_path) >= max_bytes:
                        continue
                    # Explicit encoding so decoding does not depend on the locale.
                    with open(disk_path, "r", encoding="utf-8", errors="ignore") as f:
                        content = f.read()
                except (OSError, ValueError) as e:
                    logger.warning(f"Failed to add graph-related file {neighbor_file}: {e}")
                    continue

                relation = self._edge_relation(target_node, neighbor)
                augmented_docs.append(
                    Document(
                        page_content=f"--- Graph Context ({relation} from {os.path.basename(file_path)}) ---\n{content}",
                        metadata={
                            "file_path": neighbor_file,
                            "source": "ast_graph",
                            "relation": relation,
                            "related_to": file_path,
                        },
                    )
                )
                seen_files.add(neighbor_key)
                logger.debug(f"Added graph-related file: {neighbor_file} (relation: {relation})")

        return augmented_docs
|