Spaces:
Sleeping
Sleeping
"""AST-based DAG parser and code injection utilities.

parse_source(source, module_name) -> CodeDAG
    Parse a Python source string and return a structured DAG with nodes
    (module, function, class, imported_module) and typed edges
    (contains, calls, imports).

inject_function_body(source, func_name, new_body) -> str
    Replace the body of the top-level function func_name in source with
    new_body, preserving the def line and any docstring. Used by the
    environment's step() method.

dag_to_text(dag) -> str
    Render a CodeDAG as a concise text block for the agent prompt.
"""
| from __future__ import annotations | |
| import ast | |
| from dataclasses import dataclass, field | |
# ── DAG data model ────────────────────────────────────────────────────────────
@dataclass
class DAGNode:
    """A vertex of a CodeDAG.

    For class methods, parse_source() uses the qualified ``"Class.method"``
    form as ``name``.
    """

    name: str
    node_type: str  # "module" | "function" | "class" | "imported_module"
    signature: str = ""  # "(params) -> ret" text; parse_source() only sets it for functions
    is_stub: bool = False  # True when the function still needs an implementation
    body_summary: str = ""  # short human-readable status shown next to the node
@dataclass
class DAGEdge:
    """A directed, typed edge ``source --edge_type--> target``.

    Field order matters: parse_source() builds edges positionally as
    ``DAGEdge(edge_type, source, target)``.
    """

    edge_type: str  # "contains" | "calls" | "imports"
    source: str
    target: str
@dataclass
class FunctionInfo:
    """Location and status of one top-level function, as recorded by parse_source()."""

    name: str
    signature: str
    is_stub: bool
    start_line: int  # 1-indexed line of the ``def`` keyword (decorator lines come before it)
    end_line: int  # 1-indexed, inclusive
    # True when the body opens with a constant expression statement (normally a docstring)
    has_docstring: bool
    docstring_end_line: int  # 1-indexed; == start_line when no docstring
@dataclass
class CodeDAG:
    """Parsed view of one module: its nodes, typed edges and per-function info."""

    module_name: str
    nodes: list[DAGNode] = field(default_factory=list)
    edges: list[DAGEdge] = field(default_factory=list)
    # keyed by top-level function name (parse_source() does not add methods here)
    function_infos: dict[str, FunctionInfo] = field(default_factory=dict)

    def callers_of(self, func_name: str) -> list[str]:
        """Return the sources of all ``calls`` edges that target func_name, in edge order."""
        return [e.source for e in self.edges if e.edge_type == "calls" and e.target == func_name]

    def callees_of(self, func_name: str) -> list[str]:
        """Return the targets of all ``calls`` edges leaving func_name, in edge order."""
        return [e.target for e in self.edges if e.edge_type == "calls" and e.source == func_name]

    def stub_functions(self) -> list[str]:
        """Return the names of function nodes flagged ``is_stub``, in node order."""
        return [n.name for n in self.nodes if n.node_type == "function" and n.is_stub]
# ── helpers ───────────────────────────────────────────────────────────────────
def _signature(node: ast.FunctionDef | ast.AsyncFunctionDef) -> str:
    """Render a def's parameter list and return annotation as text.

    Produces e.g. ``"(p, /, a, b: int = 2, *args, c, d=4, **kw) -> str"``.
    All parameter kinds are included: positional-only ones (followed by
    ``/``), ordinary ones, ``*args`` (or a bare ``*`` before keyword-only
    parameters), keyword-only ones and ``**kwargs``. Each is shown with its
    annotation and default when present.
    """
    args = node.args

    def fmt(arg: ast.arg, default: ast.expr | None = None) -> str:
        text = arg.arg
        if arg.annotation is not None:
            text += f": {ast.unparse(arg.annotation)}"
        if default is not None:
            # PEP 8: spaces around "=" only when the parameter is annotated.
            eq = " = " if arg.annotation is not None else "="
            text += f"{eq}{ast.unparse(default)}"
        return text

    positional = [*args.posonlyargs, *args.args]
    # ``defaults`` belong to the *last* len(defaults) positional parameters.
    pos_defaults = [None] * (len(positional) - len(args.defaults)) + list(args.defaults)
    parts = [fmt(a, d) for a, d in zip(positional, pos_defaults)]
    if args.posonlyargs:
        parts.insert(len(args.posonlyargs), "/")
    if args.vararg is not None:
        parts.append(f"*{fmt(args.vararg)}")
    elif args.kwonlyargs:
        parts.append("*")  # bare star introduces keyword-only parameters
    # kw_defaults is aligned 1:1 with kwonlyargs; None means "no default".
    parts.extend(fmt(a, d) for a, d in zip(args.kwonlyargs, args.kw_defaults))
    if args.kwarg is not None:
        parts.append(f"**{fmt(args.kwarg)}")
    ret = f" -> {ast.unparse(node.returns)}" if node.returns else ""
    return f"({', '.join(parts)}){ret}"
def _is_stub(node: ast.FunctionDef | ast.AsyncFunctionDef, source: str) -> bool:
    """Heuristically decide whether a function still needs implementing.

    A function counts as a stub in either of two cases:
    - its source text, from the ``def`` line through its last line, contains
      the marker ``# STUB``;
    - after ignoring every bare constant expression statement (docstring,
      ``...``), its body is a single ``raise NotImplementedError``, either
      bare or called.
    """
    snippet = "\n".join(source.splitlines()[node.lineno - 1 : node.end_lineno])
    if "# STUB" in snippet:
        return True
    real_stmts = [
        stmt
        for stmt in node.body
        if not (isinstance(stmt, ast.Expr) and isinstance(stmt.value, ast.Constant))
    ]
    if len(real_stmts) != 1 or not isinstance(real_stmts[0], ast.Raise):
        return False
    raised = real_stmts[0].exc
    if isinstance(raised, ast.Call):  # raise NotImplementedError(...)
        raised = raised.func
    return isinstance(raised, ast.Name) and raised.id == "NotImplementedError"
def _extract_calls(node: ast.FunctionDef | ast.AsyncFunctionDef) -> set[str]:
    """Collect the plain names called anywhere inside node.

    Only calls whose callee is a bare name (``foo(...)``) are counted.
    Attribute calls such as ``obj.foo()`` are ignored. Because ast.walk also
    descends into nested functions, lambdas and classes, their calls are
    included.
    """
    return {
        sub.func.id
        for sub in ast.walk(node)
        if isinstance(sub, ast.Call) and isinstance(sub.func, ast.Name)
    }
# ── main parser ───────────────────────────────────────────────────────────────
def parse_source(source: str, module_name: str = "module") -> CodeDAG:
    """Parse Python source into a CodeDAG.

    Nodes:
      * one ``module`` node named module_name;
      * one ``imported_module`` node per distinct imported name, wherever the
        import statement appears. For ``import x as y`` the alias is used;
        for ``from x import ...`` it is ``node.module`` (leading dots of a
        relative import are not included, and ``from . import`` is skipped);
      * one ``class`` node per top-level class;
      * one ``function`` node per top-level function and per method defined
        directly in a top-level class (named ``"Class.method"``).

    Edges:
      * ``imports`` (module -> imported name);
      * ``contains`` (module -> top-level def/class, class -> method);
      * ``calls`` (function or method -> same-module function called by bare
        name; self-calls are skipped).
    Each caller's call edges are emitted in sorted callee order, so the
    result does not depend on set iteration order.

    Only top-level functions get a FunctionInfo entry in ``function_infos``.

    Raises:
        SyntaxError: if source is not valid Python.
    """
    tree = ast.parse(source)
    dag = CodeDAG(module_name=module_name)
    dag.nodes.append(DAGNode(name=module_name, node_type="module"))
    func_names: set[str] = set()
    stub_summary = "STUB — needs implementation"
    done_summary = "(implemented)"

    # ── imports (ast.walk also reaches imports nested in functions/classes)
    seen_imports: set[str] = set()
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            imported = [alias.asname or alias.name for alias in node.names]
        elif isinstance(node, ast.ImportFrom) and node.module:
            imported = [node.module]
        else:
            continue
        for imp in imported:
            if imp in seen_imports:  # e.g. two separate "from typing import ..." lines
                continue
            seen_imports.add(imp)
            dag.nodes.append(DAGNode(name=imp, node_type="imported_module"))
            dag.edges.append(DAGEdge("imports", module_name, imp))

    # ── top-level functions and classes (plus the methods of those classes).
    # (caller name, def node) pairs scanned for call edges afterwards, once
    # every same-module function name is known.
    callers: list[tuple[str, ast.FunctionDef | ast.AsyncFunctionDef]] = []
    for node in tree.body:
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            sig = _signature(node)
            stub = _is_stub(node, source)
            # Any leading constant expression statement is treated as the docstring.
            has_doc = (
                bool(node.body)
                and isinstance(node.body[0], ast.Expr)
                and isinstance(node.body[0].value, ast.Constant)
            )
            doc_end = node.body[0].end_lineno if has_doc else node.lineno
            dag.nodes.append(DAGNode(
                name=node.name,
                node_type="function",
                signature=sig,
                is_stub=stub,
                body_summary=stub_summary if stub else done_summary,
            ))
            dag.edges.append(DAGEdge("contains", module_name, node.name))
            dag.function_infos[node.name] = FunctionInfo(
                name=node.name,
                signature=sig,
                is_stub=stub,
                start_line=node.lineno,
                end_line=node.end_lineno,
                has_docstring=has_doc,
                docstring_end_line=doc_end,
            )
            func_names.add(node.name)
            callers.append((node.name, node))
        elif isinstance(node, ast.ClassDef):
            dag.nodes.append(DAGNode(name=node.name, node_type="class"))
            dag.edges.append(DAGEdge("contains", module_name, node.name))
            for item in node.body:
                if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)):
                    qname = f"{node.name}.{item.name}"
                    stub = _is_stub(item, source)
                    dag.nodes.append(DAGNode(
                        name=qname,
                        node_type="function",
                        signature=_signature(item),
                        is_stub=stub,
                        body_summary=stub_summary if stub else done_summary,
                    ))
                    dag.edges.append(DAGEdge("contains", node.name, qname))
                    func_names.add(qname)
                    callers.append((qname, item))

    # ── call edges (same-module only). Bare call names never contain a dot,
    # so they can only resolve to top-level functions.
    for caller, fn in callers:
        for callee in sorted(_extract_calls(fn)):
            if callee in func_names and callee != caller:
                dag.edges.append(DAGEdge("calls", caller, callee))
    return dag
# ── code injection ────────────────────────────────────────────────────────────
def _indent_body(new_body: str) -> list[str]:
    """Normalise raw body text into 4-space-indented lines, each ending in a newline.

    The common leading whitespace of the non-blank lines is removed, and
    blank lines become bare newlines. If no line contains code (the text is
    empty, blank or only comments), ``pass`` is appended so the suite is valid.
    """
    raw_lines = new_body.splitlines()
    min_indent = min(
        (len(line) - len(line.lstrip()) for line in raw_lines if line.strip()),
        default=0,
    )
    body_lines = [
        "    " + line[min_indent:] + "\n" if line.strip() else "\n"
        for line in raw_lines
    ]
    has_code = any(
        line.strip() and not line.lstrip().startswith("#") for line in raw_lines
    )
    if not has_code:
        body_lines.append("    pass\n")
    return body_lines


def inject_function_body(source: str, func_name: str, new_body: str) -> str:
    """Replace the body of the top-level function func_name in source with new_body.

    The def header is preserved, including decorators and headers that span
    several lines. A leading docstring is also preserved; any constant
    expression statement that opens the body counts as one. new_body is the
    raw body text, with or without indentation — it is re-indented to 4
    spaces. A body with no code lines gets ``pass`` appended.

    A one-line definition such as ``def f(): pass`` is expanded so the header
    and the new body end up on separate lines.

    Raises:
        ValueError: if source has no top-level (async) function named func_name.
        SyntaxError: if source itself cannot be parsed.
    """
    tree = ast.parse(source)
    lines = source.splitlines(keepends=True)
    for node in tree.body:
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == func_name:
            break
    else:
        raise ValueError(f"Function {func_name!r} not found in source")

    first = node.body[0]
    has_doc = isinstance(first, ast.Expr) and isinstance(first.value, ast.Constant)
    # ast column offsets count UTF-8 bytes, so slice the encoded line.
    lead = lines[first.lineno - 1].encode("utf-8")[: first.col_offset].decode("utf-8")

    if lead.strip():
        # One-line form ("def f(): ..." or "): ..."): keep the header text up to
        # the first statement and move any docstring onto its own line.
        before = lines[: first.lineno - 1] + [lead.rstrip() + "\n"]
        if has_doc:
            before.append("    " + ast.get_source_segment(source, first) + "\n")
    else:
        # The header ends on the last non-blank, non-comment line before the
        # first statement (this covers headers spanning several lines). Comment
        # lines after it, e.g. a "# STUB" marker, are replaced with the body.
        header_end = node.lineno
        for lineno in range(node.lineno, first.lineno):
            text = lines[lineno - 1].strip()
            if text and not text.startswith("#"):
                header_end = lineno
        before = lines[: first.end_lineno if has_doc else header_end]

    after = lines[node.end_lineno:]  # everything after the function
    return "".join(before + _indent_body(new_body) + after)
# ── DAG → text description (for prompts) ──────────────────────────────────────
def dag_to_text(dag: CodeDAG) -> str:
    """Format a CodeDAG as a compact, human-readable block for the agent prompt.

    Emits a "### Nodes" section with one tagged line per node, and a
    "### Edges" section with one ``source --type--> target`` line per edge.
    Function lines carry a stub/ready marker and the signature. Nodes with an
    unrecognised node_type are omitted.
    """
    simple_tags = {
        "module": "[MODULE]",
        "class": "[CLASS]",
        "imported_module": "[IMPORT]",
    }
    out = [f"## Module: {dag.module_name}", "", "### Nodes"]
    for node in dag.nodes:
        kind = node.node_type
        if kind == "function":
            marker = "[ STUB ]" if node.is_stub else "[ready ]"
            out.append(f"- [FUNC] {marker} {node.name}{node.signature}")
        elif kind in simple_tags:
            out.append(f"- {simple_tags[kind]} {node.name}")
    out.extend(["", "### Edges"])
    out.extend(f"- {edge.source} --{edge.edge_type}--> {edge.target}" for edge in dag.edges)
    return "\n".join(out)