graphforge-openenv / env /ast_parser.py
NagaNithin-V
Deploy GraphForge OpenEnv — AST-parsed KG code-editing environment
7952f32
"""AST-based DAG parser and code injection utilities.
parse_source(source, module_name) -> CodeDAG
Parses a Python source string and returns a structured DAG with nodes
(module, function, class, imported_module) and typed edges (contains, calls, imports).
inject_function_body(source, func_name, new_body) -> str
Replaces the body of func_name in source with new_body, preserving the
def line and any docstring. Used by the environment's step() method.
"""
from __future__ import annotations
import ast
from dataclasses import dataclass, field
# ── DAG data model ────────────────────────────────────────────────────────────
@dataclass
class DAGNode:
    """A single vertex of the code graph.

    Only ``name`` and ``node_type`` are required; the remaining fields
    default to "empty" values.
    """

    # Identifier of the node (module name, function name, class name, ...).
    name: str
    # One of: "module" | "function" | "class" | "imported_module"
    node_type: str
    # Rendered parameter list / return annotation; empty when not set.
    signature: str = ""
    # True when the node represents a function still awaiting implementation.
    is_stub: bool = False
    # Short free-text description of the body; empty when not set.
    body_summary: str = ""
@dataclass
class DAGEdge:
    """A directed, typed edge between two named nodes of the code graph."""

    # One of: "contains" | "calls" | "imports"
    edge_type: str
    # Name of the node the edge starts from.
    source: str
    # Name of the node the edge points to.
    target: str
@dataclass
class FunctionInfo:
    """Location and status record for one function found in a source file.

    All line numbers are 1-indexed.
    """

    # Function name as written in the ``def`` statement.
    name: str
    # Rendered parameter list / return annotation.
    signature: str
    # True when the function is still a stub.
    is_stub: bool
    # Line of the ``def`` statement (1-indexed).
    start_line: int
    # Last line of the function (1-indexed, inclusive).
    end_line: int
    # Whether the function body opens with a docstring.
    has_docstring: bool
    # Last line of the docstring (1-indexed); == start_line when no docstring.
    docstring_end_line: int
@dataclass
class CodeDAG:
    """Graph view of one module: its nodes, typed edges and function records.

    Each container defaults to a fresh, empty instance per ``CodeDAG``.
    """

    # Name under which the module appears in the graph.
    module_name: str
    # Every node of the graph, in insertion order.
    nodes: list[DAGNode] = field(default_factory=list)
    # Every edge of the graph, in insertion order.
    edges: list[DAGEdge] = field(default_factory=list)
    # Per-function location records, keyed by function name.
    function_infos: dict[str, FunctionInfo] = field(default_factory=dict)

    def callers_of(self, func_name: str) -> list[str]:
        """Return the sources of all "calls" edges that point at *func_name*."""
        found: list[str] = []
        for edge in self.edges:
            if edge.edge_type == "calls" and edge.target == func_name:
                found.append(edge.source)
        return found

    def callees_of(self, func_name: str) -> list[str]:
        """Return the targets of all "calls" edges that start at *func_name*."""
        found: list[str] = []
        for edge in self.edges:
            if edge.edge_type == "calls" and edge.source == func_name:
                found.append(edge.target)
        return found

    def stub_functions(self) -> list[str]:
        """Return the names of all function nodes flagged as stubs."""
        pending: list[str] = []
        for entry in self.nodes:
            if entry.node_type == "function" and entry.is_stub:
                pending.append(entry.name)
        return pending
# ── helpers ───────────────────────────────────────────────────────────────────
def _signature(node: ast.FunctionDef | ast.AsyncFunctionDef) -> str:
    """Render the parameter list and return annotation of a function node.

    The result has the form ``"(<params>)"`` or ``"(<params>) -> <ret>"``.

    Every parameter kind is rendered: positional-only (with the ``/``
    marker), regular, ``*args``, keyword-only, and ``**kwargs``, together
    with annotations and default values. Functions that only use plain
    positional parameters without defaults render as ``"(a: int, b)"``.

    Args:
        node: The ``def`` / ``async def`` AST node to describe.

    Returns:
        The signature text, without the function name.
    """
    # ast.unparse on the ``arguments`` node emits the full parameter list,
    # so no parameter kind can be silently dropped.
    params = ast.unparse(node.args)
    ret = f" -> {ast.unparse(node.returns)}" if node.returns else ""
    return f"({params}){ret}"
def _is_stub(node: ast.FunctionDef | ast.AsyncFunctionDef, source: str) -> bool:
    """Decide whether a function is an unimplemented stub.

    A function counts as a stub when either
      * the text ``# STUB`` occurs anywhere on the function's own source
        lines (``def`` line through last body line), or
      * after ignoring bare constant expressions (docstrings, ``...``),
        the body is exactly one ``raise NotImplementedError`` /
        ``raise NotImplementedError(...)`` statement.

    Args:
        node: The function's AST node.
        source: Full source text the node was parsed from.
    """
    own_lines = source.splitlines()[node.lineno - 1:node.end_lineno]
    if any("# STUB" in text for text in own_lines):
        return True

    # Drop constant-expression statements so a docstring does not hide the
    # lone ``raise``.
    meaningful = [
        stmt for stmt in node.body
        if not isinstance(stmt, ast.Expr) or not isinstance(stmt.value, ast.Constant)
    ]
    if len(meaningful) != 1:
        return False

    only = meaningful[0]
    if not isinstance(only, ast.Raise):
        return False

    # Accept both ``raise X`` and ``raise X(...)``: look at the callee for calls.
    raised = only.exc
    if isinstance(raised, ast.Call):
        raised = raised.func
    return isinstance(raised, ast.Name) and raised.id == "NotImplementedError"
def _extract_calls(node: ast.FunctionDef | ast.AsyncFunctionDef) -> set[str]:
    """Collect the bare names called anywhere inside *node*.

    Only calls whose callee is a plain identifier (``foo(...)``) are
    reported; attribute calls such as ``obj.method(...)`` are skipped.
    Nested scopes are included because the whole subtree is walked.
    """
    return {
        sub.func.id
        for sub in ast.walk(node)
        if isinstance(sub, ast.Call) and isinstance(sub.func, ast.Name)
    }
# ── main parser ───────────────────────────────────────────────────────────────
def parse_source(source: str, module_name: str = "module") -> CodeDAG:
    """Parse Python source into a CodeDAG.

    The resulting graph contains:
      * one "module" node named *module_name*;
      * one "imported_module" node (plus an "imports" edge from the module)
        per distinct imported name, found anywhere in the file — the alias
        for ``import x as y``, the module for ``from x import ...``;
        ``from . import x`` (no module name) is skipped;
      * one "function" node per top-level function, with a "contains" edge
        from the module and a ``FunctionInfo`` record in ``function_infos``;
      * one "class" node per top-level class, and one "function" node named
        ``Class.method`` per method directly in the class body;
      * "calls" edges between top-level functions of this module.

    Args:
        source: Python source text.
        module_name: Name used for the module node.

    Returns:
        The populated ``CodeDAG``.

    Raises:
        SyntaxError: If *source* is not valid Python (from ``ast.parse``).
    """
    tree = ast.parse(source)
    dag = CodeDAG(module_name=module_name)
    dag.nodes.append(DAGNode(name=module_name, node_type="module"))
    func_names: set[str] = set()

    def _summary(is_stub: bool) -> str:
        # Same wording for top-level functions and methods.
        return "STUB — needs implementation" if is_stub else "(implemented)"

    # imports — the whole tree is walked, so function-local imports count too.
    # Each imported name is recorded once, however often it is imported.
    seen_imports: set[str] = set()
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            imported = [alias.asname or alias.name for alias in node.names]
        elif isinstance(node, ast.ImportFrom) and node.module:
            imported = [node.module]
        else:
            continue
        for imp in imported:
            if imp in seen_imports:
                continue
            seen_imports.add(imp)
            dag.nodes.append(DAGNode(name=imp, node_type="imported_module"))
            dag.edges.append(DAGEdge("imports", module_name, imp))

    # top-level functions and classes
    for node in tree.body:
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            sig = _signature(node)
            stub = _is_stub(node, source)
            # Only a leading *string* expression is a docstring; a leading
            # ``...`` or number is an ordinary statement.
            has_doc = ast.get_docstring(node, clean=False) is not None
            doc_end = node.body[0].end_lineno if has_doc else node.lineno
            dag.nodes.append(DAGNode(
                name=node.name,
                node_type="function",
                signature=sig,
                is_stub=stub,
                body_summary=_summary(stub),
            ))
            dag.edges.append(DAGEdge("contains", module_name, node.name))
            dag.function_infos[node.name] = FunctionInfo(
                name=node.name,
                signature=sig,
                is_stub=stub,
                start_line=node.lineno,
                end_line=node.end_lineno,
                has_docstring=has_doc,
                docstring_end_line=doc_end,
            )
            func_names.add(node.name)
        elif isinstance(node, ast.ClassDef):
            dag.nodes.append(DAGNode(name=node.name, node_type="class"))
            dag.edges.append(DAGEdge("contains", module_name, node.name))
            for item in node.body:
                if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)):
                    qname = f"{node.name}.{item.name}"
                    method_stub = _is_stub(item, source)
                    dag.nodes.append(DAGNode(
                        name=qname,
                        node_type="function",
                        signature=_signature(item),
                        is_stub=method_stub,
                        body_summary=_summary(method_stub),
                    ))
                    dag.edges.append(DAGEdge("contains", node.name, qname))
                    # NOTE: callee names from _extract_calls are bare
                    # identifiers, so a dotted qname is never matched below.
                    func_names.add(qname)

    # call edges (same-module only; self-calls are not recorded)
    for node in tree.body:
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            for callee in _extract_calls(node):
                if callee in func_names and callee != node.name:
                    dag.edges.append(DAGEdge("calls", node.name, callee))
    return dag
# ── code injection ────────────────────────────────────────────────────────────
def _find_function_def(
    tree: ast.Module, func_name: str
) -> ast.FunctionDef | ast.AsyncFunctionDef | None:
    """Locate a top-level function, or a ``Class.method``, by (dotted) name."""
    *owners, leaf = func_name.split(".")
    scope = tree.body
    for owner in owners:
        cls = next(
            (n for n in scope if isinstance(n, ast.ClassDef) and n.name == owner),
            None,
        )
        if cls is None:
            return None
        scope = cls.body
    return next(
        (
            n for n in scope
            if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef)) and n.name == leaf
        ),
        None,
    )


def _header_end_line(
    node: ast.FunctionDef | ast.AsyncFunctionDef, lines: list[str]
) -> int:
    """Return the 1-indexed line on which the ``def`` header of *node* ends.

    Only valid when the first body statement starts on its own line.
    """
    first = node.body[0]
    # A decorated first statement starts at its first decorator, not its def.
    first_start = min(
        [first.lineno] + [d.lineno for d in getattr(first, "decorator_list", [])]
    )
    line_no = first_start - 1
    # Only blank lines and comments can sit between the header's colon and the
    # first statement, so the nearest code line above the body is the header end.
    while line_no > node.lineno:
        text = lines[line_no - 1].strip()
        if text and not text.startswith("#"):
            break
        line_no -= 1
    return line_no


def inject_function_body(source: str, func_name: str, new_body: str) -> str:
    """Replace the body of func_name in source with new_body.

    Preserves the def header (including multi-line signatures) and any
    docstring. new_body should be the raw body text (with or without
    indentation — we normalise it): the common leading whitespace is
    stripped and every non-blank line is re-indented to match the existing
    body (4 spaces deeper than the ``def`` when the old body shared a line
    with the header). Comments between the header and the first statement
    are dropped together with the old body.

    Args:
        source: Full Python source text.
        func_name: Name of a top-level function, or ``Class.method`` for a
            method defined directly in a top-level class.
        new_body: Replacement body. If it has no non-blank line, ``pass``
            is inserted so the function stays syntactically valid.

    Returns:
        The source text with the function body replaced.

    Raises:
        ValueError: If no matching function is found.
        SyntaxError: If *source* is not valid Python (from ``ast.parse``).
    """
    tree = ast.parse(source)
    node = _find_function_def(tree, func_name)
    if node is None:
        raise ValueError(f"Function {func_name!r} not found in source")

    lines = source.splitlines(keepends=True)
    first = node.body[0]
    # Only a leading *string* expression is a docstring (not ``...`` or ``42``).
    has_doc = ast.get_docstring(node, clean=False) is not None

    # Text on the first statement's line that precedes it. col_offset counts
    # UTF-8 bytes, so slice the encoded line.
    first_line = lines[first.lineno - 1]
    lead = first_line.encode("utf-8")[: first.col_offset].decode("utf-8")

    if lead.strip():
        # Body shares a physical line with the header (``def f(): pass``):
        # cut the line right after the colon.
        def_line = lines[node.lineno - 1]
        indent = def_line[: len(def_line) - len(def_line.lstrip())] + "    "
        before = lines[: first.lineno - 1] + [lead.rstrip() + "\n"]
        if has_doc:
            before.append(indent + ast.get_source_segment(source, first) + "\n")
    else:
        # Reuse the existing body indentation so a kept docstring and the
        # new statements line up.
        indent = lead
        keep_until = first.end_lineno if has_doc else _header_end_line(node, lines)
        # keep_until is 1-indexed; lines[:keep_until] gives 0..keep_until-1
        before = lines[:keep_until]
        if before and not before[-1].endswith(("\n", "\r")):
            before[-1] += "\n"

    after = lines[node.end_lineno:]  # everything after the function

    # Normalise body indent: strip common leading whitespace, then re-indent.
    raw_lines = new_body.splitlines()
    min_indent = min(
        (len(raw) - len(raw.lstrip()) for raw in raw_lines if raw.strip()),
        default=0,
    )
    body_lines = [
        indent + raw[min_indent:] + "\n" if raw.strip() else "\n"
        for raw in raw_lines
    ]
    if not any(raw.strip() for raw in raw_lines):
        body_lines = [indent + "pass\n"]

    return "".join(before + body_lines + after)
# ── DAG → text description (for prompts) ─────────────────────────────────────
def dag_to_text(dag: CodeDAG) -> str:
    """Render the DAG as a concise human-readable block for the agent prompt.

    The output is a module heading, a "### Nodes" list with one bullet per
    node of a known type (nodes of any other type are omitted), and a
    "### Edges" list with one ``source --type--> target`` bullet per edge.
    """
    # Node types that render as a plain "- [TAG] name" bullet.
    plain_tags = {
        "module": "[MODULE]",
        "class": "[CLASS]",
        "imported_module": "[IMPORT]",
    }

    out: list[str] = [f"## Module: {dag.module_name}", "", "### Nodes"]
    for entry in dag.nodes:
        kind = entry.node_type
        if kind == "function":
            # Functions additionally show stub status and signature.
            badge = "[ready ]" if not entry.is_stub else "[ STUB ]"
            out.append(f"- [FUNC] {badge} {entry.name}{entry.signature}")
        elif kind in plain_tags:
            out.append(f"- {plain_tags[kind]} {entry.name}")

    out.extend(["", "### Edges"])
    out.extend(
        f"- {edge.source} --{edge.edge_type}--> {edge.target}" for edge in dag.edges
    )
    return "\n".join(out)