Spaces:
Sleeping
Sleeping
File size: 9,869 Bytes
7952f32 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 | """AST-based DAG parser and code injection utilities.
parse_source(source, module_name) -> CodeDAG
Parses a Python source string and returns a structured DAG with nodes
(module, class, function, imported_module) and typed edges (contains, calls, imports).
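inject_function_body(source, func_name, new_body) -> str
Replaces the body of the top-level function func_name in source with
new_body, preserving the def line and any docstring. Used by the
environment's step() method.
dag_to_text(dag) -> str
Renders a CodeDAG as a concise text block for the agent prompt.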
"""
from __future__ import annotations

import ast
import io
import textwrap
from dataclasses import dataclass, field
# ── DAG data model ────────────────────────────────────────────────────────────
@dataclass
class DAGNode:
    """One vertex of a CodeDAG.

    Attributes:
        name: Identifier of the node; class methods are named ``Class.method``.
        node_type: One of "module", "function", "class", "imported_module".
        signature: Rendered parameter list and return annotation (functions).
        is_stub: True when the function still needs an implementation.
        body_summary: Short human-readable status text; empty when unset.
    """

    name: str
    node_type: str  # "module" | "function" | "class" | "imported_module"
    signature: str = ""
    is_stub: bool = False
    body_summary: str = ""
@dataclass
class DAGEdge:
    """A directed, typed edge ``source --edge_type--> target`` in a CodeDAG.

    edge_type is one of "contains", "calls" or "imports"; source and target
    are DAGNode names. Field order matters because edges are built
    positionally as ``DAGEdge(edge_type, source, target)``.
    """

    edge_type: str  # "contains" | "calls" | "imports"
    source: str
    target: str
@dataclass
class FunctionInfo:
    """Location and status metadata for one function of the parsed source.

    All line numbers are 1-indexed.

    Attributes:
        name: Function name.
        signature: Rendered parameter list and return annotation.
        is_stub: True when the function still needs an implementation.
        start_line: Line of the ``def`` statement.
        end_line: Last line of the function's final statement (inclusive).
        has_docstring: Whether the body opens with a docstring.
        docstring_end_line: Last line of the docstring, or start_line when
            there is no docstring.
    """

    name: str
    signature: str
    is_stub: bool
    start_line: int
    end_line: int
    has_docstring: bool
    docstring_end_line: int
@dataclass
class CodeDAG:
    """Graph view of a single module plus per-function metadata.

    Attributes:
        module_name: Name of the module node.
        nodes: Every DAGNode in the graph.
        edges: Every DAGEdge in the graph.
        function_infos: FunctionInfo records keyed by function name.
    """

    module_name: str
    nodes: list[DAGNode] = field(default_factory=list)
    edges: list[DAGEdge] = field(default_factory=list)
    function_infos: dict[str, FunctionInfo] = field(default_factory=dict)

    def callers_of(self, func_name: str) -> list[str]:
        """Return, in edge order, the source of each "calls" edge into func_name."""
        return [
            edge.source
            for edge in self.edges
            if (edge.edge_type, edge.target) == ("calls", func_name)
        ]

    def callees_of(self, func_name: str) -> list[str]:
        """Return, in edge order, the target of each "calls" edge out of func_name."""
        return [
            edge.target
            for edge in self.edges
            if (edge.edge_type, edge.source) == ("calls", func_name)
        ]

    def stub_functions(self) -> list[str]:
        """Return, in node order, the names of function nodes flagged as stubs."""
        return [
            node.name
            for node in self.nodes
            if node.node_type == "function" and node.is_stub
        ]
# ── helpers ───────────────────────────────────────────────────────────────────
def _signature(node: ast.FunctionDef | ast.AsyncFunctionDef) -> str:
parts = []
for arg in node.args.args:
ann = f": {ast.unparse(arg.annotation)}" if arg.annotation else ""
parts.append(f"{arg.arg}{ann}")
ret = f" -> {ast.unparse(node.returns)}" if node.returns else ""
return f"({', '.join(parts)}){ret}"
def _is_stub(node: ast.FunctionDef | ast.AsyncFunctionDef, source: str) -> bool:
func_src = "\n".join(source.splitlines()[node.lineno - 1:node.end_lineno])
if "# STUB" in func_src:
return True
# body that is just "raise NotImplementedError"
stmts = [s for s in node.body
if not (isinstance(s, ast.Expr) and isinstance(s.value, ast.Constant))]
if len(stmts) == 1 and isinstance(stmts[0], ast.Raise):
exc = stmts[0].exc
if isinstance(exc, ast.Name) and exc.id == "NotImplementedError":
return True
if isinstance(exc, ast.Call) and isinstance(exc.func, ast.Name) and exc.func.id == "NotImplementedError":
return True
return False
def _extract_calls(node: ast.FunctionDef | ast.AsyncFunctionDef) -> set[str]:
calls: set[str] = set()
for child in ast.walk(node):
if isinstance(child, ast.Call):
if isinstance(child.func, ast.Name):
calls.add(child.func.id)
return calls
# ── main parser ───────────────────────────────────────────────────────────────
def parse_source(source: str, module_name: str = "module") -> CodeDAG:
    """Parse Python source into a CodeDAG.

    Nodes:
      * the module itself (named ``module_name``);
      * each imported module, recorded once however often it is imported:
        ``import x`` / ``import x as y`` contribute the bound name, and
        ``from x import ...`` contributes ``x``. Imports anywhere in the tree
        count, but ``from . import ...`` (no module name) is skipped;
      * top-level functions and classes, plus methods defined directly in a
        top-level class body (named ``Class.method``).

    Edges:
      * ``contains``: module to top-level function/class, and class to method;
      * ``imports``: module to imported module;
      * ``calls``: from a top-level function to another known function it
        calls by plain name (self-calls excluded), emitted in sorted order
        so the result is reproducible.

    ``function_infos`` is filled for top-level functions only.

    Raises:
        SyntaxError: if source is not valid Python.
    """
    tree = ast.parse(source)
    dag = CodeDAG(module_name=module_name)
    dag.nodes.append(DAGNode(name=module_name, node_type="module"))
    func_names: set[str] = set()

    # imports: one node/edge per distinct name, in first-seen order
    seen_imports: set[str] = set()
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            imported = [alias.asname or alias.name for alias in node.names]
        elif isinstance(node, ast.ImportFrom) and node.module:
            imported = [node.module]
        else:
            continue
        for imp in imported:
            if imp in seen_imports:
                continue
            seen_imports.add(imp)
            dag.nodes.append(DAGNode(name=imp, node_type="imported_module"))
            dag.edges.append(DAGEdge("imports", module_name, imp))

    # top-level functions and classes
    for node in tree.body:
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            sig = _signature(node)
            stub = _is_stub(node, source)
            # Only a leading *string* constant is a docstring; a bare ``...``
            # body must not be mistaken for one.
            has_doc = ast.get_docstring(node, clean=False) is not None
            doc_end = node.body[0].end_lineno if has_doc else node.lineno
            dag.nodes.append(DAGNode(
                name=node.name,
                node_type="function",
                signature=sig,
                is_stub=stub,
                body_summary="STUB — needs implementation" if stub else "(implemented)",
            ))
            dag.edges.append(DAGEdge("contains", module_name, node.name))
            dag.function_infos[node.name] = FunctionInfo(
                name=node.name,
                signature=sig,
                is_stub=stub,
                start_line=node.lineno,
                end_line=node.end_lineno,
                has_docstring=has_doc,
                docstring_end_line=doc_end,
            )
            func_names.add(node.name)
        elif isinstance(node, ast.ClassDef):
            dag.nodes.append(DAGNode(name=node.name, node_type="class"))
            dag.edges.append(DAGEdge("contains", module_name, node.name))
            for item in node.body:
                if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)):
                    qname = f"{node.name}.{item.name}"
                    method_stub = _is_stub(item, source)
                    dag.nodes.append(DAGNode(
                        name=qname,
                        node_type="function",
                        signature=_signature(item),
                        is_stub=method_stub,
                        # same status wording as top-level functions
                        body_summary=(
                            "STUB — needs implementation" if method_stub else "(implemented)"
                        ),
                    ))
                    dag.edges.append(DAGEdge("contains", node.name, qname))
                    func_names.add(qname)

    # call edges (same-module only). _extract_calls returns a set whose
    # iteration order varies with the string hash seed; sort it so the edge
    # list (and any prompt rendered from it) is reproducible.
    for node in tree.body:
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            for callee in sorted(_extract_calls(node)):
                if callee in func_names and callee != node.name:
                    dag.edges.append(DAGEdge("calls", node.name, callee))
    return dag
# ── code injection ────────────────────────────────────────────────────────────
def _def_header_end(node: ast.FunctionDef | ast.AsyncFunctionDef, lines: list[str]) -> int:
    """Return the 1-indexed line that ends node's ``def`` header (the line holding its ``:``).

    Requires node's first body statement to start on a line of its own.
    Only blank or comment-only lines can sit between the header's last line
    and that statement. The header's last line contains the ``:``, so it is
    never comment-only. Walking upward past such lines therefore stops
    exactly on it, even for multi-line signatures.
    """
    end = node.body[0].lineno - 1
    while end > node.lineno:
        text = lines[end - 1].strip()
        if text and not text.startswith("#"):
            break
        end -= 1
    return end


def inject_function_body(source: str, func_name: str, new_body: str) -> str:
    """Replace the body of the top-level function func_name in source with new_body.

    Everything outside the function is kept verbatim. So are its decorators,
    the complete ``def`` header (including multi-line signatures) and any
    docstring. The rest of the old body is replaced.

    new_body is raw body text, with or without indentation. Its common
    leading whitespace is removed, and each non-blank line is re-indented to
    the function's existing body indentation (four spaces when the old body
    shared the header's line, as in ``def f(): return 0``). A blank or
    whitespace-only new_body becomes ``pass``, so the result still parses.

    Only module-level functions are searched; the first match wins.

    Raises:
        ValueError: if no top-level function named func_name exists.
        SyntaxError: if source is not valid Python.
    """
    tree = ast.parse(source)
    target = next(
        (
            node
            for node in tree.body
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
            and node.name == func_name
        ),
        None,
    )
    if target is None:
        raise ValueError(f"Function {func_name!r} not found in source")

    # Split only on \n, \r\n and \r (the tokenizer's line breaks) so list
    # indices stay aligned with AST line numbers. str.splitlines() would also
    # split on form feeds, U+2028, etc.
    lines = io.StringIO(source, newline="").readlines()
    first = target.body[0]
    # ast.get_docstring only accepts a leading *string* constant, so a body
    # of just ``...`` is not mistaken for a docstring.
    has_doc = ast.get_docstring(target, clean=False) is not None

    # Text in front of the first body statement (col_offset is a UTF-8 byte offset).
    prefix = lines[first.lineno - 1].encode("utf-8")[: first.col_offset].decode("utf-8")
    if prefix.strip():
        # The body starts on the header's last line, e.g. ``def f(): return 0``.
        # Keep only the header text and move any inline docstring to its own line.
        indent = "    "
        before = lines[: first.lineno - 1] + [prefix.rstrip() + "\n"]
        if has_doc:
            before.append(indent + ast.get_source_segment(source, first) + "\n")
    else:
        indent = prefix  # whitespace in front of the existing first statement
        keep_until = first.end_lineno if has_doc else _def_header_end(target, lines)
        before = lines[:keep_until]  # keep_until is 1-indexed, so this is inclusive
        if not before[-1].endswith(("\n", "\r")):
            before[-1] += "\n"  # e.g. a docstring on the last line of the file
    after = lines[target.end_lineno:]  # everything after the function

    body_lines = [
        indent + line + "\n" if line.strip() else "\n"
        for line in textwrap.dedent(new_body).splitlines()
    ]
    if not any(line.strip() for line in body_lines):
        body_lines = [indent + "pass\n"]
    return "".join(before + body_lines + after)
# ── DAG → text description (for prompts) ──────────────────────────────────────
def dag_to_text(dag: CodeDAG) -> str:
    """Format dag as a compact, human-readable block for the agent prompt.

    Layout: a ``## Module:`` title, then a ``### Nodes`` section with one
    bullet per node (functions show a STUB/ready tag and their signature),
    then a ``### Edges`` section with one ``source --type--> target`` bullet
    per edge. Nodes whose node_type is not recognised are left out.
    """
    tags = {"module": "- [MODULE]", "class": "- [CLASS]", "imported_module": "- [IMPORT]"}

    def render(node) -> str | None:
        if node.node_type == "function":
            status = "[ STUB ]" if node.is_stub else "[ready ]"
            return f"- [FUNC] {status} {node.name}{node.signature}"
        tag = tags.get(node.node_type)
        return None if tag is None else f"{tag} {node.name}"

    rendered = (render(node) for node in dag.nodes)
    node_lines = [text for text in rendered if text is not None]
    edge_lines = [f"- {edge.source} --{edge.edge_type}--> {edge.target}" for edge in dag.edges]
    return "\n".join(
        [f"## Module: {dag.module_name}", "", "### Nodes", *node_lines, "", "### Edges", *edge_lines]
    )
|