Spaces:
Sleeping
Sleeping
File size: 10,314 Bytes
7952f32 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 | """In-memory Knowledge Graph for a Python repository.
Mirrors the structure of a Neo4j property graph but lives in RAM:
Nodes
-----
repo — the repository root
package — a directory containing __init__.py
module — a .py file
class — a class definition
function — a top-level or nested function / async function
method — a method inside a class
Edges (directed)
-----------------
contains — parent → child (repo→package, package→module, module→class, …)
calls — function/method → function/method (same-file same-package)
imports — module → module (from x import y / import x)
inherits — class → class
Each node stores the actual source lines so the agent can read/edit them.
"""
from __future__ import annotations
import textwrap
from dataclasses import dataclass, field
from typing import Iterable
# ── node & edge ───────────────────────────────────────────────────────────────
@dataclass
class KGNode:
    """A single vertex of the knowledge graph.

    Only the first four fields are required; the rest default to "empty"
    values (0, "" or a fresh dict) so lightweight nodes can be built quickly.
    """

    # Graph-wide unique key, e.g. "function:validators.py:validate_title".
    node_id: str
    # One of: module | class | function | method | package | repo.
    node_type: str
    # Short identifier for display.
    name: str
    # Path relative to the repo root (empty for repo/package nodes).
    file_path: str
    line_start: int = 0
    line_end: int = 0
    # Full source text of this node, including its def line.
    source: str = ""
    docstring: str = ""
    # Free-form extras; each instance gets its own dict.
    metadata: dict = field(default_factory=dict)

    def brief(self) -> str:
        """Return a one-line summary suitable for graph overviews.

        The node type is upper-cased and left-aligned in an 8-character
        column, followed by the node id. A ``[file:line]`` suffix is added
        only when ``file_path`` is non-empty.
        """
        label = f"[{self.node_type.upper():<8}] {self.node_id}"
        if not self.file_path:
            return label
        return f"{label} [{self.file_path}:{self.line_start}]"
@dataclass
class KGEdge:
    """A directed, typed link between two nodes, referenced by node id."""

    # Relationship kind: contains | calls | imports | inherits.
    edge_type: str
    # node_id the edge starts from.
    source_id: str
    # node_id the edge points to.
    target_id: str
# ── knowledge graph ───────────────────────────────────────────────────────────
class KnowledgeGraph:
    """Property graph for a repository, held entirely in memory.

    Nodes live in a dict keyed by ``node_id``; edges are a flat list kept in
    insertion order. Supports the queries used by the agent and the reward
    checker. All query results are returned in a deterministic order (edge or
    node insertion order) so that rendered text is stable between runs.
    """

    def __init__(self, repo_path: str) -> None:
        self.repo_path = repo_path
        self._nodes: dict[str, KGNode] = {}
        self._edges: list[KGEdge] = []

    # ── mutation ──────────────────────────────────────────────────────────────

    def add_node(self, node: KGNode) -> None:
        """Register ``node``, replacing any existing node with the same id."""
        self._nodes[node.node_id] = node

    def add_edge(self, edge: KGEdge) -> None:
        """Append ``edge``; endpoints are not validated and duplicates are kept."""
        self._edges.append(edge)

    def update_node_source(self, node_id: str, new_source: str) -> None:
        """Replace a node's source and recompute ``line_end`` from it.

        ``line_start`` is left untouched; ``line_end`` becomes
        ``line_start + number_of_lines - 1``.

        Raises:
            KeyError: if ``node_id`` is not in the graph.
        """
        node = self._nodes[node_id]
        node.source = new_source
        node.line_end = node.line_start + len(new_source.splitlines()) - 1

    def insert_node(
        self,
        parent_id: str,
        new_node: KGNode,
    ) -> None:
        """Add ``new_node`` to the graph and wire a contains edge from parent.

        ``parent_id`` is not checked for existence.
        """
        self.add_node(new_node)
        self.add_edge(KGEdge("contains", parent_id, new_node.node_id))

    def remove_node(self, node_id: str) -> None:
        """Drop the node (if present) and every edge that touches it."""
        self._nodes.pop(node_id, None)
        self._edges = [
            edge for edge in self._edges
            if node_id not in (edge.source_id, edge.target_id)
        ]

    # ── queries ───────────────────────────────────────────────────────────────

    def get_node(self, node_id: str) -> KGNode | None:
        """Return the node with this id, or ``None`` when it is unknown."""
        return self._nodes.get(node_id)

    def all_nodes(self, node_type: str | None = None) -> list[KGNode]:
        """Return all nodes, optionally restricted to one ``node_type``.

        A falsy ``node_type`` (``None`` or ``""``) means "no filter".
        """
        if not node_type:
            return list(self._nodes.values())
        return [n for n in self._nodes.values() if n.node_type == node_type]

    def _linked_nodes(
        self, node_id: str, edge_type: str, *, outgoing: bool
    ) -> list[KGNode]:
        """Return the distinct nodes joined to ``node_id`` by ``edge_type`` edges.

        With ``outgoing=True`` the edge targets are returned, otherwise the
        edge sources. Ids are de-duplicated with a dict rather than a set so
        the result follows edge insertion order instead of hash order. Ids
        that no longer resolve to a node are skipped.
        """
        found: dict[str, None] = {}
        for edge in self._edges:
            if edge.edge_type != edge_type:
                continue
            if outgoing:
                anchor, other = edge.source_id, edge.target_id
            else:
                anchor, other = edge.target_id, edge.source_id
            if anchor == node_id:
                found[other] = None
        return [self._nodes[nid] for nid in found if nid in self._nodes]

    def children_of(self, node_id: str) -> list[KGNode]:
        """Return the nodes directly contained in ``node_id``."""
        return self._linked_nodes(node_id, "contains", outgoing=True)

    def parent_of(self, node_id: str) -> KGNode | None:
        """Return the source of the first contains edge into ``node_id``.

        Returns ``None`` when there is no such edge, or when the parent id of
        that first edge is not in the graph.
        """
        for edge in self._edges:
            if edge.target_id == node_id and edge.edge_type == "contains":
                return self._nodes.get(edge.source_id)
        return None

    def callers_of(self, node_id: str) -> list[KGNode]:
        """Return the nodes that have a calls edge into ``node_id``."""
        return self._linked_nodes(node_id, "calls", outgoing=False)

    def callees_of(self, node_id: str) -> list[KGNode]:
        """Return the nodes that ``node_id`` has a calls edge to."""
        return self._linked_nodes(node_id, "calls", outgoing=True)

    def imports_of(self, module_id: str) -> list[KGNode]:
        """Return the nodes that ``module_id`` has an imports edge to."""
        return self._linked_nodes(module_id, "imports", outgoing=True)

    def search(self, keywords: str, node_type: str | None = None) -> list[KGNode]:
        """Keyword search over node names, docstrings, and source.

        ``keywords`` is split on whitespace; a node matches when every keyword
        occurs (case-insensitively, as a substring) in its name, docstring or
        source. An empty ``keywords`` string therefore matches every node.
        Results follow node insertion order.
        """
        needles = keywords.lower().split()
        hits: list[KGNode] = []
        for node in self._nodes.values():
            if node_type and node.node_type != node_type:
                continue
            haystack = f"{node.name} {node.docstring} {node.source}".lower()
            if all(needle in haystack for needle in needles):
                hits.append(node)
        return hits

    def subgraph(self, root_id: str, depth: int = 2) -> list[KGNode]:
        """BFS from ``root_id`` up to ``depth`` hops; returns all encountered nodes.

        Edges of every type are followed, in their source-to-target direction.
        Nodes are returned in discovery order (root first). Ids reached
        through edges but missing from the graph are traversed but omitted
        from the result.
        """
        # Build the adjacency map once instead of rescanning every edge for
        # every visited node.
        adjacency: dict[str, list[str]] = {}
        for edge in self._edges:
            adjacency.setdefault(edge.source_id, []).append(edge.target_id)

        # Dict keys double as an ordered "visited" set.
        discovered: dict[str, None] = {root_id: None}
        frontier = [root_id]
        for _ in range(depth):
            next_frontier: list[str] = []
            for nid in frontier:
                for target in adjacency.get(nid, ()):
                    if target not in discovered:
                        discovered[target] = None
                        next_frontier.append(target)
            if not next_frontier:
                break
            frontier = next_frontier
        return [self._nodes[nid] for nid in discovered if nid in self._nodes]

    # ── text representations ──────────────────────────────────────────────────

    def overview(self, max_chars: int = 3000) -> str:
        """Compact multi-line overview of the repo graph.

        Modules are listed in ``file_path`` order, each with its classes
        (sorted by name, methods by line) and functions (sorted by line).
        ``max_chars`` is a soft cap meant to avoid LLM context overflow: it is
        checked after each module is rendered, so the module that crosses the
        limit is still shown in full, followed by a note saying how many
        modules were left out.
        """
        lines: list[str] = [f"## Repository: {self.repo_path}", ""]
        modules = sorted(self.all_nodes("module"), key=lambda n: n.file_path)
        all_fns = self.all_nodes("function")
        all_cls = self.all_nodes("class")
        lines.append(
            f" {len(modules)} modules · {len(all_fns)} functions · {len(all_cls)} classes"
        )
        lines.append("")

        # Enumerate the *sorted* list so the "remaining" count below reflects
        # the order modules are actually rendered in.
        for shown, mod in enumerate(modules, start=1):
            children = self.children_of(mod.node_id)
            funcs = [c for c in children if c.node_type in ("function", "method")]
            classes = [c for c in children if c.node_type == "class"]

            summary = []
            if classes:
                summary.append(f"{len(classes)} class{'es' if len(classes)>1 else ''}")
            if funcs:
                summary.append(f"{len(funcs)} fn{'s' if len(funcs)>1 else ''}")
            lines.append(f" [{mod.file_path}] ({', '.join(summary) or 'empty'})")

            for cls in sorted(classes, key=lambda n: n.name):
                methods = [
                    c for c in self.children_of(cls.node_id)
                    if c.node_type == "method"
                ]
                mnames = ", ".join(
                    m.name for m in sorted(methods, key=lambda n: n.line_start)
                )
                lines.append(f" class {cls.name} — {mnames or '(no methods)'}")
                lines.append(f" node_id: {cls.node_id}")

            for fn in sorted(funcs, key=lambda n: n.line_start):
                lines.append(f" def {fn.name}{fn.metadata.get('signature', '')}")
                lines.append(f" node_id: {fn.node_id}")

            # Stop expanding once the rendered text has passed the cap.
            if len("\n".join(lines)) > max_chars:
                remaining = len(modules) - shown
                if remaining:
                    lines.append(
                        f"\n ... [{remaining} more modules not shown — use query() to explore]"
                    )
                break
        return "\n".join(lines)

    def node_detail(self, node_id: str) -> str:
        """Full inspection view of a single node.

        Includes type, location, a docstring excerpt (first 120 characters),
        callers, callees, contained children and the node's source in a
        fenced block. Returns an ``[ERROR]`` line instead of raising when the
        id is unknown.
        """
        node = self._nodes.get(node_id)
        if node is None:
            return f"[ERROR] node_id {node_id!r} not found in graph."

        lines = [
            f"## Node: {node.node_id}",
            f"type : {node.node_type}",
            f"file : {node.file_path} (lines {node.line_start}–{node.line_end})",
        ]
        if node.docstring:
            lines.append(f"docstring: {node.docstring[:120]}")

        callers = self.callers_of(node_id)
        callees = self.callees_of(node_id)
        if callers:
            lines.append("called by: " + ", ".join(n.name for n in callers))
        if callees:
            lines.append("calls : " + ", ".join(n.name for n in callees))

        children = self.children_of(node_id)
        if children:
            lines.append("contains : " + ", ".join(c.name for c in children))

        lines += ["", "### Source", "```python", node.source or "(no source)", "```"]
        return "\n".join(lines)

    def snapshot(self) -> "KnowledgeGraph":
        """Deep copy — used to preserve state before mutations."""
        import copy

        return copy.deepcopy(self)
|