Spaces:
Sleeping
Sleeping
"""In-memory Knowledge Graph for a Python repository.

Mirrors the structure of a Neo4j property graph but lives in RAM:

Nodes
-----
repo      — the repository root
package   — a directory containing __init__.py
module    — a .py file
class     — a class definition
function  — a top-level or nested function / async function
method    — a method inside a class

Edges (directed)
----------------
contains  — parent → child (repo→package, package→module, module→class, …)
calls     — function/method → function/method (same-file / same-package)
imports   — module → module (from x import y / import x)
inherits  — class → class

Each node stores the actual source lines so the agent can read/edit them.
"""
| from __future__ import annotations | |
| import textwrap | |
| from dataclasses import dataclass, field | |
| from typing import Iterable | |
# ── node & edge ───────────────────────────────────────────────────────────────
@dataclass
class KGNode:
    """A single vertex of the knowledge graph together with the source it covers.

    ``@dataclass`` is required: the fields below are declared dataclass-style
    (``metadata`` uses ``field(default_factory=dict)``), so without it the
    class has no generated ``__init__`` and ``metadata`` would be a bare
    ``Field`` object shared by every instance.
    """

    node_id: str  # unique key, e.g. "function:validators.py:validate_title"
    node_type: str  # module | class | function | method | package | repo
    name: str  # short identifier
    file_path: str  # relative path from repo root (empty for repo/package)
    line_start: int = 0
    line_end: int = 0
    source: str = ""  # full source text of this node (incl. def line)
    docstring: str = ""
    metadata: dict = field(default_factory=dict)  # free-form extras; a fresh dict per node

    def brief(self) -> str:
        """Return a one-line summary for graph overviews.

        Format: ``[TYPE    ] node_id [file_path:line_start]``. The type is
        upper-cased and left-padded to 8 characters. The location suffix is
        omitted when ``file_path`` is empty.
        """
        loc = f" [{self.file_path}:{self.line_start}]" if self.file_path else ""
        return f"[{self.node_type.upper():<8}] {self.node_id}{loc}"
@dataclass
class KGEdge:
    """A directed, typed edge between two nodes, referenced by their node_ids.

    ``@dataclass`` provides the positional constructor used as
    ``KGEdge(edge_type, source_id, target_id)`` plus value equality.
    """

    edge_type: str  # contains | calls | imports | inherits
    source_id: str  # node_id of the edge's origin
    target_id: str  # node_id of the edge's destination
# ── knowledge graph ───────────────────────────────────────────────────────────
class KnowledgeGraph:
    """Property graph for a repository.

    Nodes are stored in a dict keyed by ``node_id``. Edges are kept in a
    plain list in insertion order. Supports the queries used by the agent
    and reward checker.
    """

    def __init__(self, repo_path: str) -> None:
        self.repo_path = repo_path
        self._nodes: dict[str, KGNode] = {}
        self._edges: list[KGEdge] = []

    # ── mutation ──────────────────────────────────────────────────────────────

    def add_node(self, node: KGNode) -> None:
        """Add *node*, replacing any existing node with the same node_id."""
        self._nodes[node.node_id] = node

    def add_edge(self, edge: KGEdge) -> None:
        """Append *edge*; duplicate edges are not filtered out."""
        self._edges.append(edge)

    def update_node_source(self, node_id: str, new_source: str) -> None:
        """Replace a node's source text and recompute its ``line_end``.

        ``line_start`` is kept and ``line_end`` becomes the last line spanned
        by *new_source*. Empty source counts as one line, so ``line_end``
        never falls before ``line_start``.

        Raises:
            KeyError: if *node_id* is not in the graph.
        """
        node = self._nodes[node_id]
        node.source = new_source
        # "".splitlines() == [] would otherwise yield line_end = line_start - 1.
        n_lines = max(len(new_source.splitlines()), 1)
        node.line_end = node.line_start + n_lines - 1

    def insert_node(
        self,
        parent_id: str,
        new_node: KGNode,
    ) -> None:
        """Add *new_node* and wire a ``contains`` edge from *parent_id* to it.

        The existence of *parent_id* is not checked.
        """
        self.add_node(new_node)
        self.add_edge(KGEdge("contains", parent_id, new_node.node_id))

    def remove_node(self, node_id: str) -> None:
        """Remove a node (if present) together with every edge touching it."""
        self._nodes.pop(node_id, None)
        self._edges = [
            e for e in self._edges
            if e.source_id != node_id and e.target_id != node_id
        ]

    # ── queries ───────────────────────────────────────────────────────────────

    def _neighbors(
        self, node_id: str, edge_type: str, *, incoming: bool = False
    ) -> list[KGNode]:
        """Return nodes linked to *node_id* by *edge_type* edges.

        Follows ``node_id -> target`` edges, or ``source -> node_id`` edges
        when *incoming* is true. The result is deduplicated and ordered by
        first occurrence in the edge list, so output is stable across runs.
        Ids with no node in the graph are skipped.
        """
        if incoming:
            ids = (e.source_id for e in self._edges
                   if e.target_id == node_id and e.edge_type == edge_type)
        else:
            ids = (e.target_id for e in self._edges
                   if e.source_id == node_id and e.edge_type == edge_type)
        # dict.fromkeys deduplicates while keeping insertion order (unlike a set).
        return [self._nodes[i] for i in dict.fromkeys(ids) if i in self._nodes]

    def get_node(self, node_id: str) -> KGNode | None:
        """Return the node with *node_id*, or None if it is not in the graph."""
        return self._nodes.get(node_id)

    def all_nodes(self, node_type: str | None = None) -> list[KGNode]:
        """Return all nodes in insertion order, optionally filtered by *node_type*."""
        if not node_type:
            return list(self._nodes.values())
        return [n for n in self._nodes.values() if n.node_type == node_type]

    def children_of(self, node_id: str) -> list[KGNode]:
        """Return the nodes that *node_id* ``contains``."""
        return self._neighbors(node_id, "contains")

    def parent_of(self, node_id: str) -> KGNode | None:
        """Return the source node of the first ``contains`` edge into *node_id*.

        Returns None if there is no such edge or its source node is missing.
        """
        for e in self._edges:
            if e.target_id == node_id and e.edge_type == "contains":
                return self._nodes.get(e.source_id)
        return None

    def callers_of(self, node_id: str) -> list[KGNode]:
        """Return nodes with a ``calls`` edge into *node_id*."""
        return self._neighbors(node_id, "calls", incoming=True)

    def callees_of(self, node_id: str) -> list[KGNode]:
        """Return nodes that *node_id* has a ``calls`` edge to."""
        return self._neighbors(node_id, "calls")

    def imports_of(self, module_id: str) -> list[KGNode]:
        """Return nodes that *module_id* has an ``imports`` edge to."""
        return self._neighbors(module_id, "imports")

    def search(self, keywords: str, node_type: str | None = None) -> list[KGNode]:
        """Case-insensitive keyword search over node names, docstrings and source.

        A node matches when *every* whitespace-separated keyword occurs as a
        substring of its name, docstring or source. An empty *keywords*
        string therefore matches every node (of *node_type*, if given).
        """
        kws = keywords.lower().split()
        results: list[KGNode] = []
        for node in self._nodes.values():
            if node_type and node.node_type != node_type:
                continue
            haystack = f"{node.name} {node.docstring} {node.source}".lower()
            if all(kw in haystack for kw in kws):
                results.append(node)
        return results

    def subgraph(self, root_id: str, depth: int = 2) -> list[KGNode]:
        """Return *root_id* plus every node reachable within *depth* hops.

        Hops follow outgoing edges of any type. Ids without a node in the
        graph are traversed but not returned. Order is unspecified.
        """
        # Build the out-adjacency once instead of rescanning every edge per node.
        out_edges: dict[str, list[str]] = {}
        for e in self._edges:
            out_edges.setdefault(e.source_id, []).append(e.target_id)
        visited: set[str] = {root_id}
        frontier: set[str] = {root_id}
        for _ in range(depth):
            frontier = {
                target
                for nid in frontier
                for target in out_edges.get(nid, ())
                if target not in visited
            }
            if not frontier:
                break
            visited |= frontier
        return [self._nodes[nid] for nid in visited if nid in self._nodes]

    # ── text representations ──────────────────────────────────────────────────

    def overview(self, max_chars: int = 3000) -> str:
        """Compact multi-line overview of the repo graph, capped to avoid LLM context overflow.

        Modules are listed in file-path order, each with its classes (and
        their methods) and its functions/methods, plus their node_ids.
        Expansion stops after the first module whose lines push the text past
        *max_chars*. The result can therefore exceed the cap by one module's
        worth of lines plus a trailing "... N more modules" note.
        """
        modules = sorted(self.all_nodes("module"), key=lambda n: n.file_path)
        n_funcs = len(self.all_nodes("function"))
        n_classes = len(self.all_nodes("class"))
        lines: list[str] = [
            f"## Repository: {self.repo_path}",
            "",
            f" {len(modules)} modules · {n_funcs} functions · {n_classes} classes",
            "",
        ]
        for shown, mod in enumerate(modules, start=1):
            children = self.children_of(mod.node_id)
            funcs = [c for c in children if c.node_type in ("function", "method")]
            classes = [c for c in children if c.node_type == "class"]
            summary = []
            if classes:
                summary.append(f"{len(classes)} class{'es' if len(classes) > 1 else ''}")
            if funcs:
                summary.append(f"{len(funcs)} fn{'s' if len(funcs) > 1 else ''}")
            lines.append(f" [{mod.file_path}] ({', '.join(summary) or 'empty'})")
            for cls in sorted(classes, key=lambda n: n.name):
                methods = [c for c in self.children_of(cls.node_id) if c.node_type == "method"]
                mnames = ", ".join(m.name for m in sorted(methods, key=lambda n: n.line_start))
                lines.append(f" class {cls.name} — {mnames or '(no methods)'}")
                lines.append(f" node_id: {cls.node_id}")
            for fn in sorted(funcs, key=lambda n: n.line_start):
                lines.append(f" def {fn.name}{fn.metadata.get('signature', '')}")
                lines.append(f" node_id: {fn.node_id}")
            # Stop expanding once we are past the character cap.
            if len("\n".join(lines)) > max_chars:
                # `shown` is the position in the *sorted* order being iterated,
                # so this is exactly the number of modules not yet listed.
                remaining = len(modules) - shown
                if remaining:
                    lines.append(
                        f"\n ... [{remaining} more modules not shown — use query() to explore]"
                    )
                break
        return "\n".join(lines)

    def node_detail(self, node_id: str) -> str:
        """Full inspection view of a single node: location, neighbours, source.

        Returns an ``[ERROR] ...`` string (rather than raising) when
        *node_id* is not in the graph.
        """
        node = self._nodes.get(node_id)
        if node is None:
            return f"[ERROR] node_id {node_id!r} not found in graph."
        lines = [
            f"## Node: {node.node_id}",
            f"type : {node.node_type}",
            f"file : {node.file_path} (lines {node.line_start}–{node.line_end})",
        ]
        if node.docstring:
            # Truncated to keep the view compact.
            lines.append(f"docstring: {node.docstring[:120]}")
        callers = self.callers_of(node_id)
        callees = self.callees_of(node_id)
        if callers:
            lines.append("called by: " + ", ".join(n.name for n in callers))
        if callees:
            lines.append("calls : " + ", ".join(n.name for n in callees))
        children = self.children_of(node_id)
        if children:
            lines.append("contains : " + ", ".join(c.name for c in children))
        lines += ["", "### Source", "```python", node.source or "(no source)", "```"]
        return "\n".join(lines)

    def snapshot(self) -> "KnowledgeGraph":
        """Deep copy — used to preserve state before mutations."""
        import copy
        return copy.deepcopy(self)