diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..149a453e2c06ad87de858c984702ed5f87027c15 --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +__pycache__/ +*.pyc +*.pyo +.env +*.egg-info/ diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..f47f7e010e9d11ae9807dac8b0445505137b485c --- /dev/null +++ b/Dockerfile @@ -0,0 +1,27 @@ +# Hugging Face Space Dockerfile. +# Mirrors the root Dockerfile, exists separately because HF Spaces looks for +# the Dockerfile inside the Space root by default. + +FROM python:3.11-slim + +WORKDIR /app + +COPY pyproject.toml ./ +COPY graphforge ./graphforge +COPY env ./env +COPY openenv.yaml ./ + +RUN pip install --no-cache-dir \ + "pydantic>=2.6" \ + "fastapi>=0.110" \ + "uvicorn[standard]>=0.27" \ + "httpx>=0.27" \ + "openenv-core>=0.1.0" \ + "pyyaml>=6.0" + +ENV PYTHONUNBUFFERED=1 +ENV PYTHONPATH=/app + +EXPOSE 7860 + +CMD ["uvicorn", "env.server:app", "--host", "0.0.0.0", "--port", "7860"] diff --git a/README.md b/README.md index 3cc7917620928b0e79b294f3799caa4552158a70..70a6c752fecbdcd5192a10d33cc76d8fef1975de 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,38 @@ --- -title: Graphforge Openenv -emoji: 💻 -colorFrom: green +title: GraphForge OpenEnv +emoji: 🧱 +colorFrom: indigo colorTo: purple sdk: docker +app_port: 7860 pinned: false license: mit -short_description: A graph-first code-editing RL environment for Python repos. --- -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +# GraphForge — OpenEnv server + +Live deployment of the GraphForge environment for the Meta PyTorch OpenEnv +Hackathon. The server hosts the OpenEnv-compliant `/reset`, `/step`, `/state` +endpoints over HTTP. Anything that speaks the OpenEnv client protocol (or +plain JSON) can drive episodes. + +See the main project repo for the architecture overview, training notebook, +plots, and writeup. 
+ +## Endpoints + +``` +POST /reset → RepoEditObservation +POST /step { ... } → { observation, reward, done } +GET /state → RepoEditState +GET /healthz +``` + +## Quick smoke test + +```bash +EID=$(curl -s -X POST $SPACE_URL/reset | python3 -c "import sys,json; print(json.load(sys.stdin)['episode_id'])") +curl -s -X POST $SPACE_URL/step -H 'content-type: application/json' \ + -d '{"kind": "query", "keywords": "validate"}' \ + | python3 -m json.tool +``` diff --git a/env/__init__.py b/env/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..22926036d1f56a4b6f462040d73b683b1d1d6a85 --- /dev/null +++ b/env/__init__.py @@ -0,0 +1,34 @@ +"""Multi-turn repo-editing OpenEnv environment. + +Public surface: + RepoEditAction, RepoEditObservation, RepoEditState — wire models + RepoEditEnvironment — OpenEnv environment + RepoEditEnv — HTTP client +""" + +from env.actions import ( + AddNodeAction, + InspectAction, + QueryAction, + RemoveNodeAction, + RepoEditAction, + SubmitAction, + UpdateNodeAction, +) +from env.client import RepoEditEnv +from env.environment import RepoEditEnvironment +from env.models import RepoEditObservation, RepoEditState + +__all__ = [ + "AddNodeAction", + "InspectAction", + "QueryAction", + "RemoveNodeAction", + "RepoEditAction", + "RepoEditEnv", + "RepoEditEnvironment", + "RepoEditObservation", + "RepoEditState", + "SubmitAction", + "UpdateNodeAction", +] diff --git a/env/actions.py b/env/actions.py new file mode 100644 index 0000000000000000000000000000000000000000..82717404d0204eeaf29651c569c6c28037972a8d --- /dev/null +++ b/env/actions.py @@ -0,0 +1,90 @@ +"""Action schema for the multi-turn repo-editing environment. + +All actions are expressed as JSON dicts with a "kind" discriminator. +The agent emits one action per turn inside ... XML tags. + +Actions +------- +query Search the knowledge graph for relevant nodes. 
+inspect View the full source of a specific node. +add_node Insert a new function or class into a module/class. +update_node Replace the source of an existing node. +remove_node Delete a node from the graph. +submit Apply all pending changes, run tests, end the episode. +""" + +from __future__ import annotations + +from typing import Literal + +from pydantic import BaseModel, ConfigDict + + +_cfg = ConfigDict(extra="forbid") + + +class QueryAction(BaseModel): + model_config = _cfg + kind: Literal["query"] = "query" + keywords: str + node_type: str = "all" # "all" | "function" | "class" | "module" | "method" + + +class InspectAction(BaseModel): + model_config = _cfg + kind: Literal["inspect"] = "inspect" + node_id: str + + +class AddNodeAction(BaseModel): + model_config = _cfg + kind: Literal["add_node"] = "add_node" + parent_id: str # node_id of the parent (module or class) + name: str # name of the new function/class + node_type: str # "function" | "class" + code: str # full source of the new node (incl. def/class line) + + +class UpdateNodeAction(BaseModel): + model_config = _cfg + kind: Literal["update_node"] = "update_node" + node_id: str # which node to replace + new_code: str # full replacement source (incl. 
def parse_action(raw: dict) -> RepoEditAction:
    """Validate a raw action payload and build the matching action model.

    Args:
        raw: Decoded JSON object carrying a ``"kind"`` discriminator plus the
            fields of that action.

    Returns:
        The validated action model selected by ``raw["kind"]``.

    Raises:
        ValueError: If ``raw`` is not a dict, or ``"kind"`` is missing, not a
            string, or not one of the known kinds. Field-level validation
            errors come from the model's own ``model_validate``.
    """
    if not isinstance(raw, dict):
        raise ValueError(
            f"Action must be a JSON object, got {type(raw).__name__}"
        )

    mapping = {
        "query": QueryAction,
        "inspect": InspectAction,
        "add_node": AddNodeAction,
        "update_node": UpdateNodeAction,
        "remove_node": RemoveNodeAction,
        "submit": SubmitAction,
    }
    kind = raw.get("kind", "")
    # A non-string kind (e.g. a list) is unhashable and would make the dict
    # lookup raise TypeError; report it as an unknown kind instead so callers
    # only ever have to handle ValueError.
    cls = mapping.get(kind) if isinstance(kind, str) else None
    if cls is None:
        raise ValueError(f"Unknown action kind: {kind!r}. Valid: {list(mapping)}")
    return cls.model_validate(raw)
def _signature(node: ast.FunctionDef | ast.AsyncFunctionDef) -> str:
    """Render a function's parameter list and return annotation as text.

    Every parameter kind is included: positional-only (followed by ``/``),
    regular, ``*args``, keyword-only (introduced by a bare ``*`` when there is
    no ``*args``) and ``**kwargs``. Annotations are kept; default values are
    deliberately omitted.

    Args:
        node: The parsed ``def`` / ``async def`` node.

    Returns:
        Text such as ``"(a: int, *rest, key: str) -> bool"``.
    """

    def _fmt(arg: ast.arg, prefix: str = "") -> str:
        ann = f": {ast.unparse(arg.annotation)}" if arg.annotation else ""
        return f"{prefix}{arg.arg}{ann}"

    spec = node.args
    parts = [_fmt(a) for a in spec.posonlyargs]
    if spec.posonlyargs:
        parts.append("/")
    parts.extend(_fmt(a) for a in spec.args)
    if spec.vararg is not None:
        parts.append(_fmt(spec.vararg, "*"))
    elif spec.kwonlyargs:
        # Keyword-only parameters without *args need the bare "*" separator.
        parts.append("*")
    parts.extend(_fmt(a) for a in spec.kwonlyargs)
    if spec.kwarg is not None:
        parts.append(_fmt(spec.kwarg, "**"))

    ret = f" -> {ast.unparse(node.returns)}" if node.returns else ""
    return f"({', '.join(parts)}){ret}"
def inject_function_body(source: str, func_name: str, new_body: str) -> str:
    """Replace the body of top-level function ``func_name`` with ``new_body``.

    The ``def`` header (including multi-line signatures) and a leading string
    docstring are preserved; everything else in the old body, including
    comments between the header and the first statement, is dropped.
    ``new_body`` may carry any common indentation; it is normalised to four
    spaces. A body with no non-blank lines becomes ``pass``.

    Args:
        source: Full module source containing the function.
        func_name: Name of a top-level ``def`` / ``async def`` in ``source``.
        new_body: Replacement body text.

    Returns:
        The module source with the function body replaced.

    Raises:
        SyntaxError: If ``source`` cannot be parsed.
        ValueError: If no top-level function named ``func_name`` exists.
    """
    tree = ast.parse(source)
    lines = source.splitlines(keepends=True)

    for node in tree.body:
        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            continue
        if node.name != func_name:
            continue

        first = node.body[0]
        # Only a string constant is a docstring; `...` or a bare number is not.
        has_doc = (
            isinstance(first, ast.Expr)
            and isinstance(first.value, ast.Constant)
            and isinstance(first.value.value, str)
        )

        # col_offset is a UTF-8 byte offset, so slice the encoded line.
        first_line = lines[first.lineno - 1]
        prefix = first_line.encode("utf-8")[: first.col_offset].decode("utf-8")

        if prefix.strip():
            # Body starts on the header's own line ("def f(): return 1"):
            # cut the line right after the colon.
            before = lines[: first.lineno - 1] + [prefix.rstrip() + "\n"]
            if has_doc:
                before.append("    " + ast.get_source_segment(source, first) + "\n")
        elif has_doc:
            before = lines[: first.end_lineno]
        else:
            # Walk back over blank/comment-only lines to the real end of the
            # header; node.lineno alone is wrong for multi-line signatures.
            header_end = first.lineno - 1
            while (
                header_end > node.lineno
                and lines[header_end - 1].lstrip()[:1] in ("", "#")
            ):
                header_end -= 1
            before = lines[:header_end]

        # The final line of a file may lack a newline; add one so the new body
        # is not glued onto it.
        if before and not before[-1].endswith("\n"):
            before[-1] += "\n"

        after = lines[node.end_lineno:]  # everything after the function

        # Normalise indentation: strip the common prefix, re-add four spaces.
        raw_lines = new_body.splitlines()
        indents = [len(ln) - len(ln.lstrip()) for ln in raw_lines if ln.strip()]
        if not indents:
            # Empty or whitespace-only body would leave the def without a suite.
            body_lines = ["    pass\n"]
        else:
            cut = min(indents)
            body_lines = [
                "    " + ln[cut:] + "\n" if ln.strip() else "\n"
                for ln in raw_lines
            ]

        return "".join(before + body_lines + after)

    raise ValueError(f"Function {func_name!r} not found in source")
class RepoEditEnv:
    """Synchronous HTTP client for the repo-editing environment server.

    Wraps the ``/reset``, ``/step`` and ``/state`` endpoints. Use it as a
    context manager, or call ``close()`` explicitly, to release the underlying
    ``httpx.Client``.
    """

    def __init__(self, base_url: str = "http://localhost:8000", timeout: float = 60.0) -> None:
        """Open a client against ``base_url`` (a trailing slash is stripped)."""
        self._client = httpx.Client(base_url=base_url.rstrip("/"), timeout=timeout)

    def reset(self, task_id: str | None = None) -> RepoEditObservation:
        """Start a new episode and return its first observation.

        ``task_id`` is sent as a query parameter only when it is non-empty.
        Raises ``httpx.HTTPStatusError`` on a non-2xx response.
        """
        params = {"task_id": task_id} if task_id else {}
        r = self._client.post("/reset", params=params)
        r.raise_for_status()
        return RepoEditObservation.model_validate(r.json())

    def step(self, action_dict: dict[str, Any]) -> dict[str, Any]:
        """POST one action and return the server's decoded JSON reply.

        Raises ``httpx.HTTPStatusError`` on a non-2xx response.
        """
        r = self._client.post("/step", json=action_dict)
        r.raise_for_status()
        return r.json()

    def state(self) -> RepoEditState:
        """Fetch the current episode state snapshot.

        Raises ``httpx.HTTPStatusError`` on a non-2xx response.
        """
        r = self._client.get("/state")
        r.raise_for_status()
        return RepoEditState.model_validate(r.json())

    def close(self) -> None:
        """Release the underlying HTTP client."""
        self._client.close()

    def __enter__(self) -> "RepoEditEnv":
        return self

    def __exit__(self, *_: object) -> None:
        self.close()
+ +step() The agent emits one RepoEditAction per turn: + - query โ†’ search results (information, no graph mutation) + - inspect โ†’ full node source (information) + - add_node โ†’ insert new function/class into the live graph + - update_node โ†’ replace a node's source in the live graph + - remove_node โ†’ delete a node + - submit โ†’ materialise all changes back to disk (temp), run tests, + compute reward, end episode + +Reward structure (sparse โ€” designed for long-horizon RL) +--------------------------------------------------------- + Per-turn cost : -0.05 (forces efficiency) + Malformed action : -0.2 + On submit + all tests pass : +1.0 + partial pass : +0.5 * (n_pass / n_total) + compile error : 0.0 + Episode cap hit : 0.0 + +This sparse reward deliberately requires the agent to plan, navigate, and +execute across many turns โ€” it cannot succeed by guessing on the first turn. +""" + +from __future__ import annotations + +import ast +import json +import os +import re +import sys +import tempfile +import textwrap +import traceback +import uuid +from pathlib import Path +from typing import Any + +from env.actions import ( + AddNodeAction, + InspectAction, + QueryAction, + RemoveNodeAction, + RepoEditAction, + SubmitAction, + UpdateNodeAction, + parse_action, +) +from env.models import RepoEditObservation, RepoEditState +from env.tasks import SAMPLE_REPOS_DIR, TASK_BANK, RepoTask, all_task_ids, get_task +from graphforge.knowledge_graph import KGEdge, KGNode, KnowledgeGraph +from graphforge.repo_parser import parse_repo, _node_id + +try: + from openenv.core import Environment # type: ignore + _HAS_OPENENV = True +except Exception: + _HAS_OPENENV = False + from typing import Generic, TypeVar + A = TypeVar("A") + O = TypeVar("O") + S = TypeVar("S") + + class Environment(Generic[A, O, S]): # type: ignore[no-redef] + def reset(self) -> O: ... + def step(self, action: A) -> tuple[O, float, bool]: ... + def get_state(self) -> S: ... 
def _apply_add_node(
    module_source: str,
    code: str,
    class_name: str | None = None,
) -> str:
    """Insert ``code`` into ``module_source`` and return the new source.

    The snippet is dedented and stripped first. With ``class_name=None`` it is
    appended at module level. Otherwise it is re-indented to match the body of
    the top-level class ``class_name`` and spliced in right after that class's
    last line. If the module cannot be parsed or has no such top-level class,
    the dedented snippet is appended at module level instead, so the result
    never gains a stray indented block.

    Args:
        module_source: Current source of the module.
        code: Source of the function/class to add, at any indentation.
        class_name: Name of the top-level class to add the snippet to, or
            ``None`` for a module-level addition.

    Returns:
        The updated module source.
    """
    new_code = textwrap.dedent(code).strip()
    module_level = module_source.rstrip() + "\n\n\n" + new_code + "\n"
    if class_name is None:
        return module_level

    try:
        tree = ast.parse(module_source)
    except (SyntaxError, ValueError):
        # Unparseable module: best effort is a module-level append.
        return module_level

    lines = module_source.splitlines(keepends=True)
    for node in tree.body:
        if not (isinstance(node, ast.ClassDef) and node.name == class_name):
            continue

        # Reuse the class body's own indentation when its first statement sits
        # on its own line; otherwise default to four spaces.
        indent = "    "
        first = node.body[0]
        if first.lineno > node.lineno:
            first_line = lines[first.lineno - 1]
            indent = first_line[: len(first_line) - len(first_line.lstrip())] or indent

        indented = "\n".join(
            indent + ln if ln.strip() else "" for ln in new_code.splitlines()
        )
        insert_at = node.end_lineno  # 1-indexed last line of the class
        before = "".join(lines[:insert_at])
        after = "".join(lines[insert_at:])
        return before.rstrip() + "\n\n" + indented + "\n" + after

    # Class not found at top level.
    return module_level
    def __init__(self, task_id: str | None = None) -> None:
        """Create an idle environment; no episode exists until ``reset()`` runs.

        Args:
            task_id: Optional default task. ``reset()`` falls back to it when
                called without an explicit ``task_id`` or ``task``.
        """
        self._configured_task_id = task_id
        # Set by reset(): the active task and the knowledge graph parsed from
        # its repo. step() refuses to run while either is still None.
        self._task: RepoTask | None = None
        self._kg: KnowledgeGraph | None = None
        # Per-episode bookkeeping, re-initialised on every reset().
        self._episode_id: str | None = None
        self._turn: int = 0
        self._done: bool = False
        self._total_reward: float = 0.0
        # One entry per step(): turn number, action kind and turn reward.
        self._history: list[dict[str, Any]] = []
Use query/inspect to navigate, then add_node/update_node to edit, then submit.", + done=False, + ) + + def step(self, action: RepoEditAction) -> tuple[RepoEditObservation, float, bool]: + if self._task is None or self._kg is None: + raise RuntimeError("step() called before reset()") + if self._done: + return self._terminal_obs("Episode already done."), 0.0, True + + self._turn += 1 + turn_reward = PER_TURN_COST + + # Dispatch + try: + result_text, extra_reward, done = self._dispatch(action) + turn_reward += extra_reward + except Exception as exc: + result_text = f"[ERROR] {exc}" + turn_reward += MALFORMED_PENALTY + done = False + + self._total_reward += turn_reward + + # Episode cap + if not done and self._turn >= self._task.max_turns: + done = True + result_text += f"\n[Episode cap reached: {self._task.max_turns} turns]" + + self._done = done + self._history.append({ + "turn": self._turn, + "action_kind": getattr(action, "kind", "unknown"), + "reward": turn_reward, + }) + + obs = RepoEditObservation( + episode_id=self._episode_id, + task_id=self._task.task_id, + turn=self._turn, + max_turns=self._task.max_turns, + graph_overview=self._kg.overview(), + task_description=self._task.description, + action_result=result_text, + turn_reward=turn_reward, + total_reward=self._total_reward, + done=done, + ) + return obs, turn_reward, done + + def get_state(self) -> RepoEditState: + return RepoEditState( + episode_id=self._episode_id, + task_id=self._task.task_id if self._task else None, + turn=self._turn, + done=self._done, + total_reward=self._total_reward, + ) + + @property + def state(self) -> RepoEditState: + return self.get_state() + + # ----- action dispatch ---------------------------------------------------- + + def _dispatch( + self, action: RepoEditAction + ) -> tuple[str, float, bool]: + """Returns (result_text, extra_reward, done).""" + kg = self._kg + assert kg is not None + + if isinstance(action, QueryAction): + nt = None if action.node_type == "all" else 
action.node_type + results = kg.search(action.keywords, node_type=nt) + if not results: + return f"No nodes found for query: {action.keywords!r}", 0.0, False + lines = [f"Found {len(results)} node(s) matching {action.keywords!r}:"] + for n in results[:10]: + lines.append(f" {n.node_id} ({n.file_path}:{n.line_start})") + return "\n".join(lines), 0.0, False + + if isinstance(action, InspectAction): + detail = kg.node_detail(action.node_id) + return detail, 0.0, False + + if isinstance(action, AddNodeAction): + parent = kg.get_node(action.parent_id) + if parent is None: + return f"[ERROR] parent_id {action.parent_id!r} not found.", MALFORMED_PENALTY, False + ok, err = _validate_python(action.code) + if not ok: + return f"[SYNTAX ERROR in your code] {err}", MALFORMED_PENALTY, False + + # Append to parent module's source + module_node = _find_module_for(kg, action.parent_id) + if module_node is None: + return f"[ERROR] could not find module for parent {action.parent_id!r}", MALFORMED_PENALTY, False + + parent_node = kg.get_node(action.parent_id) + class_name = parent_node.name if parent_node and parent_node.node_type == "class" else None + module_node.source = _apply_add_node(module_node.source, action.code, class_name=class_name) + + # Register the new node in the KG + ntype = action.node_type if action.node_type in ("function", "class", "method") else "function" + new_id = _node_id(ntype, module_node.file_path, action.name) + new_node = KGNode( + node_id=new_id, + node_type=ntype, + name=action.name, + file_path=module_node.file_path, + line_start=module_node.line_end, + line_end=module_node.line_end + action.code.count("\n") + 1, + source=textwrap.dedent(action.code).strip(), + ) + kg.insert_node(action.parent_id, new_node) + return f"Added {ntype} `{action.name}` to `{module_node.file_path}`.\nNew node_id: {new_id}", 0.0, False + + if isinstance(action, UpdateNodeAction): + target = kg.get_node(action.node_id) + if target is None: + return f"[ERROR] node_id 
{action.node_id!r} not found.", MALFORMED_PENALTY, False + ok, err = _validate_python(action.new_code) + if not ok: + return f"[SYNTAX ERROR in your code] {err}", MALFORMED_PENALTY, False + + module_node = _find_module_for(kg, action.node_id) + if module_node is None: + return f"[ERROR] could not find module for {action.node_id!r}", MALFORMED_PENALTY, False + + old_source = target.source + module_node.source = _apply_update_node(module_node.source, old_source, action.new_code) + target.source = textwrap.dedent(action.new_code).strip() + return f"Updated `{action.node_id}`.", 0.0, False + + if isinstance(action, RemoveNodeAction): + target = kg.get_node(action.node_id) + if target is None: + return f"[ERROR] node_id {action.node_id!r} not found.", MALFORMED_PENALTY, False + module_node = _find_module_for(kg, action.node_id) + if module_node: + module_node.source = _apply_remove_node(module_node.source, target.source) + kg.remove_node(action.node_id) + return f"Removed `{action.node_id}`.", 0.0, False + + if isinstance(action, SubmitAction): + return self._run_submit() + + return f"[ERROR] unrecognised action type: {type(action)}", MALFORMED_PENALTY, False + + def _run_submit(self) -> tuple[str, float, bool]: + """Write modified sources to a temp dir, run tests there, clean up.""" + kg = self._kg + task = self._task + assert kg is not None and task is not None + + reward, msg = _run_tests_in_tempdir(kg, task.test_code, task.repo_name) + return f"[SUBMIT RESULT]\n{msg}", reward, True + + def _terminal_obs(self, msg: str) -> RepoEditObservation: + return RepoEditObservation( + episode_id=self._episode_id, + task_id=self._task.task_id if self._task else None, + turn=self._turn, + max_turns=self._task.max_turns if self._task else 0, + graph_overview="", + task_description="", + action_result=msg, + done=True, + total_reward=self._total_reward, + ) + + +# โ”€โ”€ helpers 
def _run_tests_in_tempdir(
    kg: KnowledgeGraph, test_code: str, pkg_name: str
) -> tuple[float, str]:
    """Write the graph's module sources to a temp package and run the tests.

    A package directory ``<tmp>/<pkg_name>`` is created with an empty
    ``__init__.py``, and every module node's current source is written beneath
    it (a root-level ``__init__.py`` node is skipped). ``<tmp>`` is put at the
    front of ``sys.path`` while ``test_code`` runs, so the tests are expected
    to import the package by ``pkg_name``. Cached copies of the package are
    purged from ``sys.modules`` before and after the run.

    SECURITY: ``test_code`` and the agent-edited modules it imports execute
    in-process via ``exec`` with no sandbox. Only run this with trusted tasks
    inside an isolated container.

    Args:
        kg: Graph whose ``"module"`` nodes provide ``file_path`` and ``source``.
        test_code: Python source made of assertions against the package.
        pkg_name: Package name used for the temp directory and for imports.

    Returns:
        ``(1.0, message)`` when the tests run without raising, otherwise
        ``(0.0, message)`` describing the failure.
    """

    def _purge_cached_modules() -> None:
        # Drop the package and its submodules so the next import re-reads disk.
        prefix = pkg_name + "."
        for key in [k for k in sys.modules if k == pkg_name or k.startswith(prefix)]:
            del sys.modules[key]

    with tempfile.TemporaryDirectory() as tmpdir:
        pkg_dir = Path(tmpdir) / pkg_name
        pkg_dir.mkdir(parents=True)
        (pkg_dir / "__init__.py").write_text("")

        # Write each module's current (potentially mutated) source.
        for node in kg.all_nodes("module"):
            if not node.file_path or node.file_path == "__init__.py":
                continue
            dest = pkg_dir / node.file_path
            dest.parent.mkdir(parents=True, exist_ok=True)
            dest.write_text(node.source, encoding="utf-8")

        _purge_cached_modules()
        sys.path.insert(0, tmpdir)
        try:
            # A named pseudo-file keeps traceback lines readable.
            exec(compile(test_code, "<task_tests>", "exec"), {})  # noqa: S102
            return 1.0, "✓ All tests passed!"
        except AssertionError as exc:
            return 0.0, f"✗ Test failed: {exc}"
        except (Exception, SystemExit):
            # SystemExit is not an Exception subclass; without catching it a
            # stray sys.exit() in the executed code would escape the
            # environment's step() error handling.
            return 0.0, f"✗ Exception during tests:\n{traceback.format_exc(limit=5)}"
        finally:
            sys.path.remove(tmpdir)
            _purge_cached_modules()
def _make_app() -> FastAPI:
    """Build the FastAPI application that exposes the environment over HTTP.

    Every route delegates to the module-level ``_env`` instance:

      POST /reset    -> ``_env.reset(task_id=...)``, returned as RepoEditObservation
      POST /step     -> parse the JSON body with ``parse_action`` and call ``_env.step``
      GET  /state    -> ``_env.get_state()``, returned as RepoEditState
      GET  /healthz  -> static ``{"status": "ok"}``
    """
    api = FastAPI(title="Repo-Edit OpenEnv", version="0.3.0")

    @api.post("/reset", response_model=RepoEditObservation)
    def reset(task_id: str | None = None) -> RepoEditObservation:
        return _env.reset(task_id=task_id)

    @api.post("/step")
    def step(action_dict: dict[str, Any]) -> dict[str, Any]:
        try:
            parsed = parse_action(action_dict)
            obs, reward, done = _env.step(parsed)
        except (ValueError, RuntimeError) as err:
            # ValueError / RuntimeError from parsing or stepping map to HTTP 400.
            raise HTTPException(status_code=400, detail=str(err)) from err
        response: dict[str, Any] = {"observation": obs.model_dump()}
        response["reward"] = reward
        response["done"] = done
        return response

    @api.get("/state", response_model=RepoEditState)
    def state() -> RepoEditState:
        return _env.get_state()

    @api.get("/healthz")
    def healthz() -> dict[str, Any]:
        return {"status": "ok"}

    return api
@dataclass
class RepoTask:
    """One repo-editing task: target repo, instructions and acceptance test.

    Instances are stored in ``TASK_BANK`` keyed by ``task_id`` (see ``_reg``).
    """

    task_id: str                     # key under which the task is registered in TASK_BANK
    repo_name: str                   # package name (used as tempdir subdir)
    description: str                 # natural-language task for the agent
    test_code: str                   # Python assertions using short imports
    max_turns: int = 15              # maximum number of turns allowed
    difficulty: int = 0              # 0=easy, 1=medium, 2=hard
    hints: list[str] = field(default_factory=list)  # optional hint strings; empty by default
    repo_path: str | None = None     # if set, full path to repo source dir
+ """).strip(), + test_code=textwrap.dedent("""\ + from datetime import date + from task_manager.validators import validate_due_date + assert validate_due_date(None) is True, "None is valid (no deadline)" + assert validate_due_date(date(2025, 1, 1)) is True, "date object is valid" + assert validate_due_date("2025-01-01") is False, "string is not valid" + assert validate_due_date(20250101) is False, "int is not valid" + assert validate_due_date([]) is False, "list is not valid" + """).strip(), + max_turns=12, + hints=[ + "Look in validators.py to see the style of existing validators.", + "The function signature should be: def validate_due_date(due_date) -> bool", + "Import datetime.date inside the function or at the top of validators.py.", + ], +)) + +# โ”€โ”€ Task 1: add Task.is_overdue โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +_reg(RepoTask( + task_id="t1.is_overdue", + repo_name="task_manager", + description=textwrap.dedent("""\ + Add a method `is_overdue(self, today: date) -> bool` to the `Task` + class in `models.py`. + + The method should return True if: + - the task has a due_date AND + - today is strictly after the due_date AND + - the task is not yet done + + It should return False if there is no due_date, or if the task is done, + or if today <= due_date. 
+ """).strip(), + test_code=textwrap.dedent("""\ + from datetime import date + from task_manager.models import Task + + t_past = Task("x", "low", [], due_date=date(2020, 1, 1)) + t_future = Task("y", "low", [], due_date=date(2099, 1, 1)) + t_none = Task("z", "low", [], due_date=None) + t_done = Task("d", "low", [], due_date=date(2020, 1, 1)) + t_done.complete() + + today = date.today() + assert t_past.is_overdue(today) is True, "past due date โ†’ overdue" + assert t_future.is_overdue(today) is False, "future due date โ†’ not overdue" + assert t_none.is_overdue(today) is False, "no due date โ†’ not overdue" + assert t_done.is_overdue(today) is False, "done task โ†’ not overdue" + """).strip(), + max_turns=15, + difficulty=1, + hints=[ + "The Task class is in models.py.", + "The method should check self.due_date, today, and self.done.", + ], +)) + +# โ”€โ”€ Task 2: add TaskStore.find_by_tag โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +_reg(RepoTask( + task_id="t2.find_by_tag", + repo_name="task_manager", + description=textwrap.dedent("""\ + Add a method `find_by_tag(self, tag: str) -> list[Task]` to the + `TaskStore` class in `storage.py`. + + The method should return a list of all tasks that have `tag` in + their `tags` list. Return an empty list if no tasks match. 
+ """).strip(), + test_code=textwrap.dedent("""\ + from task_manager.models import Task + from task_manager.storage import TaskStore + + store = TaskStore() + store.add(Task("t1", "high", ["python", "backend"], None)) + store.add(Task("t2", "low", ["frontend"], None)) + store.add(Task("t3", "medium", ["python"], None)) + + result = store.find_by_tag("python") + assert len(result) == 2, f"Expected 2, got {len(result)}" + titles = {t.title for t in result} + assert titles == {"t1", "t3"}, f"Wrong titles: {titles}" + + empty = store.find_by_tag("devops") + assert empty == [], f"Expected [], got {empty}" + """).strip(), + max_turns=15, + difficulty=1, +)) + +# โ”€โ”€ Task 3 (hard): enforce priority validation in api.create_task โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +_reg(RepoTask( + task_id="t3.enforce_priority", + repo_name="task_manager", + description=textwrap.dedent("""\ + Update the `create_task` function in `api.py` so that it validates + the `priority` argument using `validate_priority` from `validators.py`. + + If the priority is invalid, raise `ValueError` with a clear message. + The existing validations for title and tags must still work. + + Note: `validate_priority` already exists in validators.py. + You must import and call it inside `create_task`. 
+ """).strip(), + test_code=textwrap.dedent("""\ + from task_manager import api as _api + _api.reset_store() # clean state between runs + + # valid priority passes through + t = _api.create_task("Buy milk", priority="high") + assert t.priority == "high" + + # invalid priority raises ValueError + raised = False + try: + _api.create_task("Bad task", priority="urgent") + except ValueError: + raised = True + assert raised, "create_task should raise ValueError for invalid priority" + + # title validation still works + raised2 = False + try: + _api.create_task("", priority="low") + except ValueError: + raised2 = True + assert raised2, "create_task should still reject empty title" + """).strip(), + max_turns=18, + difficulty=2, + hints=[ + "api.py already imports validate_title and validate_tags from validators.", + "You need to also import validate_priority and call it in create_task.", + ], +)) + + +# โ”€โ”€ Humanize tasks (real-world library) โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +_reg(RepoTask( + task_id="t4.intpercent", + repo_name="humanize", + description=textwrap.dedent("""\ + Add a function `intpercent(value: float, decimal_places: int = 1) -> str` + to `number.py`. + + The function should convert a fraction to a percentage string: + 0.0 โ†’ "0.0%" + 0.5 โ†’ "50.0%" + 0.753 โ†’ "75.3%" + 1.0 โ†’ "100.0%" + + Use `decimal_places` to control how many digits appear after the decimal. + If decimal_places=0, return an integer percentage with no decimal point. 
+ """).strip(), + test_code=textwrap.dedent("""\ + from humanize.number import intpercent + assert intpercent(0.0) == "0.0%", f"got {intpercent(0.0)!r}" + assert intpercent(0.5) == "50.0%", f"got {intpercent(0.5)!r}" + assert intpercent(0.753) == "75.3%", f"got {intpercent(0.753)!r}" + assert intpercent(1.0) == "100.0%", f"got {intpercent(1.0)!r}" + assert intpercent(0.5, decimal_places=0) == "50%", f"got {intpercent(0.5, decimal_places=0)!r}" + """).strip(), + max_turns=12, + difficulty=0, + hints=[ + "Look at number.py โ€” the existing functions show the style to follow.", + "Use f-string formatting: f'{value * 100:.{decimal_places}f}%'", + ], +)) + +_reg(RepoTask( + task_id="t5.naturalfilecount", + repo_name="humanize", + description=textwrap.dedent("""\ + Add a function `naturalfilecount(n: int) -> str` to `filesize.py`. + + The function should return a human-readable file count: + 0 โ†’ "no files" + 1 โ†’ "1 file" + 2 โ†’ "2 files" + 99 โ†’ "99 files" + """).strip(), + test_code=textwrap.dedent("""\ + from humanize.filesize import naturalfilecount + assert naturalfilecount(0) == "no files", f"got {naturalfilecount(0)!r}" + assert naturalfilecount(1) == "1 file", f"got {naturalfilecount(1)!r}" + assert naturalfilecount(2) == "2 files", f"got {naturalfilecount(2)!r}" + assert naturalfilecount(99) == "99 files", f"got {naturalfilecount(99)!r}" + """).strip(), + max_turns=12, + difficulty=0, + hints=[ + "Look at filesize.py โ€” naturalsize is the only function there.", + "This is a short function: handle n==0, n==1, and n>1 as three cases.", + ], +)) + +_reg(RepoTask( + task_id="t6.metric", + repo_name="humanize", + description=textwrap.dedent("""\ + Add a function `metric(value: float, unit: str = "") -> str` to `number.py`. + + The function should format a number using SI metric prefixes: + 1_500_000 โ†’ "1.5 M" + 2_000 โ†’ "2.0 k" + 500 โ†’ "500" (no prefix below 1000) + + Supported prefixes (largest to smallest): T (10ยนยฒ), G (10โน), M (10โถ), k (10ยณ). 
# Task 7: add humanize.time.age.
# The test code builds "N years ago" dates through a small helper because
# date.replace(year=...) raises ValueError when today is Feb 29 and the target
# year is not a leap year; the helper falls back to Feb 28 so a correct
# solution is not scored as failing on leap days.
_reg(RepoTask(
    task_id="t7.age",
    repo_name="humanize",
    description=textwrap.dedent("""\
        Add a function `age(birth_date) -> str` to `time.py`.

        The function receives a `datetime.date` and returns a human-readable age:
          - If the person is under 1 year old, return "X months old" (use 30-day months).
          - If exactly 1 year, return "1 year old".
          - Otherwise return "X years old".

        Use `datetime.date.today()` as the reference point.
        Assume birth_date is always a valid date in the past.
    """).strip(),
    test_code=textwrap.dedent("""\
        import datetime as dt
        from humanize.time import age

        def _years_ago(day, years):
            try:
                return day.replace(year=day.year - years)
            except ValueError:
                # Feb 29 has no counterpart in a non-leap year: use Feb 28.
                return day.replace(year=day.year - years, day=28)

        today = dt.date.today()
        dob_25y = _years_ago(today, 25)
        dob_1y = _years_ago(today, 1)
        dob_6m = today - dt.timedelta(days=182)
        dob_2m = today - dt.timedelta(days=61)

        assert age(dob_25y) == "25 years old", f"got {age(dob_25y)!r}"
        assert age(dob_1y) == "1 year old", f"got {age(dob_1y)!r}"
        assert age(dob_6m) == "6 months old", f"got {age(dob_6m)!r}"
        assert age(dob_2m) == "2 months old", f"got {age(dob_2m)!r}"
    """).strip(),
    max_turns=15,
    difficulty=1,
    hints=[
        "import datetime as dt is already at the top of time.py.",
        "days = (dt.date.today() - birth_date).days; years = days // 365; months = days // 30",
    ],
))
def _reload_task_manager() -> None:
    """Evict cached task_manager modules from ``sys.modules``.

    Every entry whose name starts with
    ``graphforge.sample_repos.task_manager`` is deleted, so a later import
    of those modules executes their source again.
    """
    prefix = "graphforge.sample_repos.task_manager"
    # Iterate over a snapshot of the keys: the dict shrinks while we delete.
    for cached_name in tuple(sys.modules):
        if cached_name.startswith(prefix):
            del sys.modules[cached_name]
+""" + +__version__ = "0.0.1" diff --git a/graphforge/actions/__init__.py b/graphforge/actions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1e5b5bd770d975615e7daff3f36863d9ee89d2e9 --- /dev/null +++ b/graphforge/actions/__init__.py @@ -0,0 +1,15 @@ +"""Action surface for GraphForge. + +Public API: + + from graphforge.actions import dispatch, ActionResult + from graphforge.actions.schema import Action, AddNode, ... + from graphforge.actions.errors import ActionError + +See PROPOSAL.md ยง4 for the full action vocabulary. +""" + +from graphforge.actions.dispatcher import ActionResult, dispatch +from graphforge.actions.errors import ActionError + +__all__ = ["ActionError", "ActionResult", "dispatch"] diff --git a/graphforge/actions/dispatcher.py b/graphforge/actions/dispatcher.py new file mode 100644 index 0000000000000000000000000000000000000000..2fbcbeb03f2833fc93eb4d562ee20aacc1c60789 --- /dev/null +++ b/graphforge/actions/dispatcher.py @@ -0,0 +1,442 @@ +"""Atomic action dispatcher. + +Applies an :class:`Action` to a :class:`Graph`. Every mutation is atomic: +the dispatcher snapshots the graph before the handler runs and restores it on +any failure. Failures surface as :class:`ActionError` with a stable code, never +as silent partial state. + +Information actions (query_*, materialize_*, run_*) are routed but their +implementations live in their respective subsystems and are stubbed for now. +``submit`` returns a sentinel so the episode runner can recognize termination. 
@dataclass
class ActionResult:
    """Outcome envelope handed back by :func:`dispatch`.

    ``ok`` tells whether the action applied, ``payload`` carries either the
    handler's keyword results or the serialized error, and ``terminal`` is
    set only on results built by :meth:`terminate`.
    """

    ok: bool
    payload: dict[str, Any]
    terminal: bool = False

    @classmethod
    def success(cls, **payload: Any) -> "ActionResult":
        """Wrap a handler's keyword results in a non-terminal success."""
        return cls(True, payload)

    @classmethod
    def failure(cls, err: E.ActionError) -> "ActionResult":
        """Wrap ``err.to_dict()`` in a non-terminal failure."""
        return cls(False, err.to_dict())

    @classmethod
    def terminate(cls, **payload: Any) -> "ActionResult":
        """Build a successful result flagged ``terminal=True``."""
        return cls(True, payload, True)
def _route(graph: Graph, action: Action) -> ActionResult:
    """Run the handler that matches ``action``'s class against ``graph``.

    The table is scanned in order with ``isinstance``, so the first matching
    entry wins.  Raises ``ActionError(SCHEMA_REJECTION)`` when no entry
    matches the action's type.
    """
    handlers = (
        # Mutations
        (AddModule, _h_add_module),
        (RemoveModule, _h_remove_module),
        (AddNode, _h_add_node),
        (RemoveNode, _h_remove_node),
        (SetNodeModule, _h_set_node_module),
        (AttachBody, _h_attach_body),
        (AddEdge, _h_add_edge),
        (RemoveEdge, _h_remove_edge),
        # Information (delegated; stubs for now)
        (QuerySpec, _h_query_spec),
        (QuerySubgraph, _h_query_subgraph),
        (QueryTypes, _h_query_types),
        (MaterializeAndValidate, _h_materialize),
        (RunBehavioralTests, _h_run_tests),
        (Submit, _h_submit),
    )
    for action_cls, handler in handlers:
        if isinstance(action, action_cls):
            return handler(graph, action)
    raise E.ActionError(E.SCHEMA_REJECTION, f"unknown action: {type(action).__name__}")
def _h_remove_module(graph: Graph, a: RemoveModule) -> ActionResult:
    """Delete module ``a.name``, provided it exists and holds no nodes.

    Raises ``ActionError`` with ``UNKNOWN_MODULE`` when the module is absent
    and ``MODULE_NOT_EMPTY`` (carrying ``node_count``) when nodes still
    belong to it.
    """
    target = a.name
    if graph.find_module(target) is None:
        raise E.ActionError(E.UNKNOWN_MODULE, f"module {target!r} does not exist", name=target)
    occupants = sum(1 for node in graph.nodes if node.module == target)
    if occupants:
        raise E.ActionError(
            E.MODULE_NOT_EMPTY,
            f"module {target!r} still contains nodes",
            name=target,
            node_count=occupants,
        )
    graph.modules = [m for m in graph.modules if m.name != target]
    return ActionResult.success(removed_module=target)
def _h_remove_node(graph: Graph, a: RemoveNode) -> ActionResult:
    """Delete a node unless it is missing or still an endpoint of an edge.

    Raises ``ActionError`` with ``UNKNOWN_NODE`` when the node is absent and
    ``NODE_HAS_REFERENCES`` (listing the ``(caller, callee)`` pairs) when any
    edge uses its qualified name as caller or callee.
    """
    victim = graph.find_node(a.name, a.module)
    if victim is None:
        raise E.ActionError(
            E.UNKNOWN_NODE, f"node {a.module}.{a.name} does not exist", name=a.name, module=a.module
        )
    qn = victim.qualified_name
    touching = [
        (edge.caller, edge.callee)
        for edge in graph.edges
        if edge.caller == qn or edge.callee == qn
    ]
    if touching:
        raise E.ActionError(
            E.NODE_HAS_REFERENCES,
            f"node {qn} is referenced by {len(touching)} edge(s)",
            name=a.name,
            module=a.module,
            referencing_edges=touching,
        )
    doomed = (a.name, a.module)
    graph.nodes = [n for n in graph.nodes if (n.name, n.module) != doomed]
    return ActionResult.success(removed_node=qn)
def _h_attach_body(graph: Graph, a: AttachBody) -> ActionResult:
    """Bind a body template and its args to an existing node.

    Checks run in this order: the node must exist (``UNKNOWN_NODE``), the
    template must be known (``UNKNOWN_TEMPLATE``), ``validate_args`` must
    report no problems, and the template's ``edges_ok`` must accept the
    node's current fan-out / fan-in (both ``TEMPLATE_ARGS_INVALID``).
    """
    node = graph.find_node(a.name, a.module)
    if node is None:
        raise E.ActionError(
            E.UNKNOWN_NODE,
            f"node {a.module}.{a.name} does not exist",
            name=a.name,
            module=a.module,
        )

    template_spec = get_template(a.template)
    if template_spec is None:
        raise E.ActionError(
            E.UNKNOWN_TEMPLATE, f"unknown template {a.template!r}", template=a.template
        )

    problems = validate_args(a.template, a.args)
    if problems:
        raise E.ActionError(
            E.TEMPLATE_ARGS_INVALID,
            f"args invalid for template {a.template!r}: {'; '.join(problems)}",
            template=a.template,
            problems=problems,
        )

    out_d = graph.fan_out(node.qualified_name)
    in_d = graph.fan_in(node.qualified_name)
    if not template_spec.edges_ok(out_d, in_d):
        raise E.ActionError(
            E.TEMPLATE_ARGS_INVALID,
            f"template {a.template!r} requires different edge structure "
            f"(out_d={out_d}, in_d={in_d})",
            template=a.template,
            out_degree=out_d,
            in_degree=in_d,
        )

    # Copy the args so later changes to the action do not leak into the node.
    node.body_template = a.template
    node.body_template_args = dict(a.args)
    attached = {"node": node.qualified_name, "template": a.template}
    return ActionResult.success(attached=attached)
a.callee) is not None: + raise E.ActionError( + E.DUPLICATE_EDGE, + f"edge {a.caller} -> {a.callee} already exists", + caller=a.caller, + callee=a.callee, + ) + # Validate arg_mapping covers all required parameters of callee. + callee_sig = parse_signature(callee.signature) + caller_sig = parse_signature(caller.signature) + mapped_callee = {m.callee_param for m in a.arg_mapping} + mapped_caller = {m.caller_arg for m in a.arg_mapping} + missing = set(callee_sig.required_params) - mapped_callee + if missing: + raise E.ActionError( + E.ARG_MAPPING_INVALID, + f"arg_mapping is missing required callee params: {sorted(missing)}", + missing=sorted(missing), + ) + bogus_callee = mapped_callee - set(callee_sig.all_params) + if bogus_callee: + raise E.ActionError( + E.ARG_MAPPING_INVALID, + f"arg_mapping references unknown callee params: {sorted(bogus_callee)}", + unknown=sorted(bogus_callee), + ) + bogus_caller = mapped_caller - set(caller_sig.all_params) + if bogus_caller: + raise E.ActionError( + E.ARG_MAPPING_INVALID, + f"arg_mapping references unknown caller args: {sorted(bogus_caller)}", + unknown=sorted(bogus_caller), + ) + # Add tentatively; check post-condition. 
def _h_remove_edge(graph: Graph, a: RemoveEdge) -> ActionResult:
    """Drop the edge ``a.caller -> a.callee`` from the graph.

    Raises ``ActionError(UNKNOWN_EDGE)`` when ``graph.find_edge`` finds no
    such edge.  Every edge with matching endpoints is filtered out.
    """
    if graph.find_edge(a.caller, a.callee) is None:
        raise E.ActionError(
            E.UNKNOWN_EDGE,
            f"edge {a.caller} -> {a.callee} does not exist",
            caller=a.caller,
            callee=a.callee,
        )
    endpoints = (a.caller, a.callee)
    graph.edges = [edge for edge in graph.edges if (edge.caller, edge.callee) != endpoints]
    return ActionResult.success(removed_edge={"caller": a.caller, "callee": a.callee})
def _h_materialize(graph: Graph, a: MaterializeAndValidate) -> ActionResult:
    """Render the graph to source files and run ``full_check`` on them.

    The success payload holds the generated file names, the summed ``len``
    of every file's content and the validator report as a dict.  A
    ``ValueError`` raised by the materializer is re-raised as
    ``ActionError(SCHEMA_REJECTION)``.
    """
    from graphforge.materializer import materialize as _materialize
    from graphforge.validator import full_check

    try:
        rendered = _materialize(graph)
    except ValueError as codegen_err:
        # Codegen rejected the graph (e.g. unknown pattern, template/edge
        # structure mismatch missed by the dispatcher's preconditions).
        raise E.ActionError(
            E.SCHEMA_REJECTION, f"materialization failed: {codegen_err}"
        ) from codegen_err

    check = full_check(rendered)
    file_names = list(rendered)
    total_len = sum(map(len, rendered.values()))
    return ActionResult.success(
        files=file_names,
        bytes_total=total_len,
        report=check.to_dict(),
    )
class ActionError(Exception):
    """Raised by action handlers; caught and reported by the dispatcher.

    Attributes:
        code: stable machine-readable error code (e.g. ``"schema_rejection"``).
        message: human-readable description of the failure.
        details: extra keyword context supplied by the raiser.

    ``str(err)`` renders as ``"[<code>] <message>"``.
    """

    def __init__(self, code: str, message: str, **details: Any) -> None:
        super().__init__(f"[{code}] {message}")
        self.code = code
        self.message = message
        self.details = details

    def to_dict(self) -> dict[str, Any]:
        """Serialize as ``{"error": code, "message": message, ...details}``.

        Key order is ``error``, ``message``, then the details in insertion
        order.  A detail whose key collides with ``error`` or ``message`` is
        left out of the dict (it stays available on ``self.details``) so the
        stable code and message can never be overwritten.
        """
        envelope: dict[str, Any] = {"error": self.code, "message": self.message}
        for key, value in self.details.items():
            # setdefault keeps the reserved entries when a detail collides.
            envelope.setdefault(key, value)
        return envelope
class AddNode(BaseModel):
    """Wire shape for ``add_node``: declare a new node inside a module.

    The dispatcher's handler rejects the action when ``module`` is unknown,
    when a node with the same name already exists in that module, or when
    ``signature`` fails ``parse_signature``.
    """

    model_config = _cfg
    kind: Literal["add_node"] = "add_node"
    name: str
    module: str  # module the node is added to
    signature: str  # signature text, e.g. "(a: int) -> bool"
    purity: Purity = "impure"
    error_policy: ErrorPolicy = "none"
arg_mapping: list[ArgMapping] = Field(default_factory=list) + + +class RemoveEdge(BaseModel): + model_config = _cfg + kind: Literal["remove_edge"] = "remove_edge" + caller: str + callee: str + + +# ---- information actions -------------------------------------------- + + +class QuerySpec(BaseModel): + model_config = _cfg + kind: Literal["query_spec"] = "query_spec" + constraint_kind: Optional[str] = None + + +class QuerySubgraph(BaseModel): + model_config = _cfg + kind: Literal["query_subgraph"] = "query_subgraph" + scope: str # "module:" | "neighbors:" | "path::" + + +class QueryTypes(BaseModel): + model_config = _cfg + kind: Literal["query_types"] = "query_types" + scope: str # "all" | "module:" | "node:" + + +class MaterializeAndValidate(BaseModel): + model_config = _cfg + kind: Literal["materialize_and_validate"] = "materialize_and_validate" + + +class RunBehavioralTests(BaseModel): + model_config = _cfg + kind: Literal["run_behavioral_tests"] = "run_behavioral_tests" + materialized: bool = True + + +# ---- terminal -------------------------------------------------------- + + +class Submit(BaseModel): + model_config = _cfg + kind: Literal["submit"] = "submit" + + +# ---- discriminated union -------------------------------------------- + +Action = Annotated[ + Union[ + AddModule, + RemoveModule, + AddNode, + RemoveNode, + SetNodeModule, + AttachBody, + AddEdge, + RemoveEdge, + QuerySpec, + QuerySubgraph, + QueryTypes, + MaterializeAndValidate, + RunBehavioralTests, + Submit, + ], + Field(discriminator="kind"), +] + + +__all__ = [ + "Action", + "AddModule", + "RemoveModule", + "AddNode", + "RemoveNode", + "SetNodeModule", + "AttachBody", + "AddEdge", + "RemoveEdge", + "QuerySpec", + "QuerySubgraph", + "QueryTypes", + "MaterializeAndValidate", + "RunBehavioralTests", + "Submit", +] diff --git a/graphforge/actions/signature.py b/graphforge/actions/signature.py new file mode 100644 index 
0000000000000000000000000000000000000000..f853a7a513c673e51958dbe2c39355614f25b394 --- /dev/null +++ b/graphforge/actions/signature.py @@ -0,0 +1,116 @@ +"""Cheap signature parser. + +Used by the dispatcher to validate ``add_edge`` arg-mappings against the +callee's parameter list. Real type flow validation (caller_arg type vs +callee_param type) is the type engine; this module only extracts parameter +*names* from a signature string of the form:: + + (a: int, b: str = "x", *, c: bool) -> bool + +Annotations are tolerated as opaque text. Defaults are tolerated and treated +as making the parameter optional. +""" + +from __future__ import annotations + +import re +from dataclasses import dataclass + + +@dataclass(frozen=True) +class Parameter: + name: str + annotation: str | None + has_default: bool + + +@dataclass(frozen=True) +class ParsedSignature: + parameters: list[Parameter] + return_annotation: str + + @property + def required_params(self) -> list[str]: + return [p.name for p in self.parameters if not p.has_default] + + @property + def all_params(self) -> list[str]: + return [p.name for p in self.parameters] + + +_SIG_RE = re.compile(r"^\s*\((?P.*)\)\s*->\s*(?P.+?)\s*$", re.DOTALL) + + +def parse_signature(sig: str) -> ParsedSignature: + """Parse a function signature string. Lenient โ€” caller validates more deeply. + + Raises ``ValueError`` on signatures that fail surface checks. The schema + layer (Node validator) already requires ``(`` and ``->``; this is the + secondary parse used at dispatch time. 
+ """ + m = _SIG_RE.match(sig) + if not m: + raise ValueError(f"could not parse signature: {sig!r}") + raw_params = m.group("params").strip() + ret = m.group("ret").strip() + + params: list[Parameter] = [] + if raw_params: + for piece in _split_top_level(raw_params, ","): + piece = piece.strip() + if not piece or piece in {"*", "/"}: + continue + if piece.startswith("**"): + piece = piece[2:].lstrip() + elif piece.startswith("*"): + piece = piece[1:].lstrip() + has_default = False + if "=" in piece: + # split off default at top-level '=' (ignore ones inside [..]). + head, default = _split_default(piece) + piece = head.strip() + has_default = default is not None + name = piece + annotation: str | None = None + if ":" in piece: + name, annotation = piece.split(":", 1) + name = name.strip() + annotation = annotation.strip() + if not name.isidentifier(): + raise ValueError(f"unparseable parameter {piece!r} in {sig!r}") + params.append(Parameter(name=name, annotation=annotation, has_default=has_default)) + + return ParsedSignature(parameters=params, return_annotation=ret) + + +def _split_top_level(s: str, sep: str) -> list[str]: + """Split ``s`` on ``sep`` at bracket-depth 0.""" + out: list[str] = [] + depth = 0 + buf: list[str] = [] + for ch in s: + if ch in "([{": + depth += 1 + elif ch in ")]}": + depth -= 1 + if ch == sep and depth == 0: + out.append("".join(buf)) + buf = [] + else: + buf.append(ch) + if buf: + out.append("".join(buf)) + return out + + +def _split_default(piece: str) -> tuple[str, str | None]: + """Split off ``= default`` at bracket-depth 0. 
Returns (head, default | None).""" + depth = 0 + for i, ch in enumerate(piece): + if ch in "([{": + depth += 1 + elif ch in ")]}": + depth -= 1 + elif ch == "=" and depth == 0: + return piece[:i], piece[i + 1 :] + return piece, None diff --git a/graphforge/behavioral/__init__.py b/graphforge/behavioral/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..819cb3abdb861b54abf4592aade6f2ff083980e5 --- /dev/null +++ b/graphforge/behavioral/__init__.py @@ -0,0 +1,25 @@ +"""Behavioral test runner. + +Responsibilities (PROPOSAL.md ยง2.1, ยง6.2): + + * Run a property-based test suite (hypothesis) against materialized code, + in a sandboxed subprocess with timeout + memory limit. + * Tests are part of the task definition; their bodies are *hidden* from + the agent. The agent sees only test names and pass/fail at submission. + * Distinguish failures (assertion) from errors (timeout, crash) โ€” both + count as test failures, but they're surfaced separately for diagnostics. + +Public surface (TODO): + + run_tests(files, tests, timeout=12.0) -> dict[str, TestResult] +""" + +from __future__ import annotations + + +def run_tests( # pragma: no cover โ€” TODO + files: dict[str, str], + tests: list[object], + timeout: float = 12.0, +) -> dict[str, object]: + raise NotImplementedError("behavioral runner TODO โ€” see PROPOSAL.md ยง6.2") diff --git a/graphforge/constraints/__init__.py b/graphforge/constraints/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7b65de419a05fa2b2772b7e83586f16e1acaf5e1 --- /dev/null +++ b/graphforge/constraints/__init__.py @@ -0,0 +1,49 @@ +"""Constraint vocabulary and dispatch. 
+ +Three families (PROPOSAL.md §2.2): + + * Structural — node_exists, edge_exists, module_count, acyclic_imports, + fan_in_max, fan_out_max, dag_depth_max, internal_only, … + * Type / signature — signature_matches, return_type, arg_type, + type_consistency, no_any_types, pure_function (TODO) + * Behavioral / materialization — materializes, imports_resolve, + type_checks, behavioral_test_passes, error_handling_present|absent + +Currently shipped: tier-0 subset of structural + ``materializes``. Additional +kinds land as new discriminated members in :mod:`schema` and matching +``_check_*`` functions in :mod:`checker`. +""" + +from graphforge.constraints.checker import ( + SatisfactionReport, + check, + evaluate_all, +) +from graphforge.constraints.schema import ( + AcyclicImports, + Constraint, + EdgeExists, + Materializes, + ModuleCount, + ModuleResponsibility, + ModuleSizeMax, + NodeAbsent, + NodeExists, + STRUCTURAL_KINDS, +) + +__all__ = [ + "AcyclicImports", + "Constraint", + "EdgeExists", + "Materializes", + "ModuleCount", + "ModuleResponsibility", + "ModuleSizeMax", + "NodeAbsent", + "NodeExists", + "STRUCTURAL_KINDS", + "SatisfactionReport", + "check", + "evaluate_all", +] diff --git a/graphforge/constraints/checker.py b/graphforge/constraints/checker.py new file mode 100644 index 0000000000000000000000000000000000000000..8a54277b5420b867ee0a2110e7793ea20ef0ee8b --- /dev/null +++ b/graphforge/constraints/checker.py @@ -0,0 +1,141 @@ +"""Constraint checker dispatch. + +Each constraint kind has a small ``_check_*`` function. ``check`` routes by +isinstance and ``evaluate_all`` reports which constraints from a list are +satisfied or not. + +Behavioral / materialization constraints (currently just ``materializes``) +delegate to the materializer and validator subsystems. 
@dataclass
class SatisfactionReport:
    """Outcome of evaluating a list of constraints against one graph."""

    satisfied: list[Constraint] = field(default_factory=list)
    unsatisfied: list[Constraint] = field(default_factory=list)

    @property
    def total(self) -> int:
        """Number of constraints recorded, satisfied or not."""
        return len(self.satisfied) + len(self.unsatisfied)

    @property
    def all_satisfied(self) -> bool:
        """True only when at least one constraint exists and none failed."""
        if self.unsatisfied:
            return False
        return self.total > 0

    def split_by_family(self) -> tuple["SatisfactionReport", "SatisfactionReport"]:
        """Split into (structural, behavioral) sub-reports.

        A constraint is structural when its ``kind`` is in
        ``STRUCTURAL_KINDS``; everything else is behavioral. The reward
        engine scores the two families with different magnitudes
        (PROPOSAL.md §5.2).
        """
        structural = SatisfactionReport()
        behavioral = SatisfactionReport()

        def family(constraint: Constraint) -> "SatisfactionReport":
            if constraint.kind in STRUCTURAL_KINDS:
                return structural
            return behavioral

        for constraint in self.satisfied:
            family(constraint).satisfied.append(constraint)
        for constraint in self.unsatisfied:
            family(constraint).unsatisfied.append(constraint)
        return structural, behavioral

    def to_dict(self) -> dict[str, object]:
        """Plain-dict form; constraints are dumped with ``model_dump()``."""
        passed = [c.model_dump() for c in self.satisfied]
        failed = [c.model_dump() for c in self.unsatisfied]
        return {
            "satisfied": passed,
            "unsatisfied": failed,
            "total": self.total,
            "all_satisfied": self.all_satisfied,
        }


# ---- per-kind checkers ----------------------------------------------


def _check_node_exists(g: Graph, c: NodeExists) -> bool:
    """The (name, module) node is present in the graph."""
    node = g.find_node(c.name, c.module)
    return node is not None


def _check_node_absent(g: Graph, c: NodeAbsent) -> bool:
    """The (name, module) node is not present in the graph."""
    node = g.find_node(c.name, c.module)
    return node is None


def _check_edge_exists(g: Graph, c: EdgeExists) -> bool:
    """An edge with exactly these caller / callee endpoints is present."""
    edge = g.find_edge(c.caller, c.callee)
    return edge is not None


def _check_module_count(g: Graph, c: ModuleCount) -> bool:
    """The graph declares exactly ``c.n`` modules."""
    declared = len(g.modules)
    return declared == c.n


def _check_module_size_max(g: Graph, c: ModuleSizeMax) -> bool:
    """The module hosts at most ``c.n`` nodes."""
    members = g.nodes_in_module(c.module)
    return len(members) <= c.n


def _check_module_responsibility(g: Graph, c: ModuleResponsibility) -> bool:
    """The module exists and carries the expected responsibility tag."""
    found = g.find_module(c.module)
    if found is None:
        return False
    return found.responsibility == c.responsibility


def _check_acyclic_imports(g: Graph, _c: AcyclicImports) -> bool:
    """The cross-module import graph has no directed cycle."""
    cyclic = g.has_module_cycle()
    return not cyclic


def _check_materializes(g: Graph, _c: Materializes) -> bool:
    """The graph materializes and the result passes ``full_check``.

    Any exception raised by ``materialize`` counts as "does not materialize".
    """
    # Deferred imports: callers that never evaluate this kind do not pay for
    # loading the materializer / validator packages.
    from graphforge.materializer import materialize
    from graphforge.validator import full_check

    try:
        files = materialize(g)
    except Exception:
        return False
    outcome = full_check(files)
    return outcome.ok


# ---- dispatch --------------------------------------------------------


def check(graph: Graph, constraint: Constraint) -> bool:
    """Evaluate one constraint against ``graph``.

    Raises:
        ValueError: If ``constraint`` is not an instance of a known kind.
    """
    # Probed in order; the first matching model type decides the checker.
    routes = (
        (NodeExists, _check_node_exists),
        (NodeAbsent, _check_node_absent),
        (EdgeExists, _check_edge_exists),
        (ModuleCount, _check_module_count),
        (ModuleSizeMax, _check_module_size_max),
        (ModuleResponsibility, _check_module_responsibility),
        (AcyclicImports, _check_acyclic_imports),
        (Materializes, _check_materializes),
    )
    for model, checker in routes:
        if isinstance(constraint, model):
            return checker(graph, constraint)
    raise ValueError(f"unknown constraint kind: {constraint!r}")


def evaluate_all(graph: Graph, constraints: list[Constraint]) -> SatisfactionReport:
    """Check every constraint and sort it into the matching report bucket."""
    report = SatisfactionReport()
    for constraint in constraints:
        bucket = report.satisfied if check(graph, constraint) else report.unsatisfied
        bucket.append(constraint)
    return report
# Shared model config: unknown fields are rejected on every constraint model.
_cfg = ConfigDict(extra="forbid")


# ---- structural ------------------------------------------------------
# Each model is a pure data record; evaluation lives in
# graphforge.constraints.checker. ``kind`` is the union discriminator.


# Satisfied when the graph has a node with this (name, module).
class NodeExists(BaseModel):
    model_config = _cfg
    kind: Literal["node_exists"] = "node_exists"
    name: str
    module: str


# Satisfied when the graph has no node with this (name, module).
class NodeAbsent(BaseModel):
    model_config = _cfg
    kind: Literal["node_absent"] = "node_absent"
    name: str
    module: str


# Satisfied when an edge with exactly these endpoints is present.
class EdgeExists(BaseModel):
    model_config = _cfg
    kind: Literal["edge_exists"] = "edge_exists"
    caller: str  # qualified
    callee: str  # qualified


# Satisfied when the graph declares exactly ``n`` modules (equality, not a
# bound). ``n`` must be >= 0.
class ModuleCount(BaseModel):
    model_config = _cfg
    kind: Literal["module_count"] = "module_count"
    n: int = Field(..., ge=0)


# Satisfied when the named module hosts at most ``n`` nodes. ``n`` must be >= 0.
class ModuleSizeMax(BaseModel):
    model_config = _cfg
    kind: Literal["module_size_max"] = "module_size_max"
    module: str
    n: int = Field(..., ge=0)


# Satisfied when the named module exists and carries this responsibility tag.
class ModuleResponsibility(BaseModel):
    model_config = _cfg
    kind: Literal["module_responsibility"] = "module_responsibility"
    module: str
    responsibility: ResponsibilityTag


# Satisfied when the cross-module import graph has no directed cycle.
# Carries no payload beyond ``kind``.
class AcyclicImports(BaseModel):
    model_config = _cfg
    kind: Literal["acyclic_imports"] = "acyclic_imports"


# ---- behavioral / materialization -----------------------------------


# Satisfied when the graph materializes and the result passes the validator's
# full check. Carries no payload beyond ``kind``.
class Materializes(BaseModel):
    model_config = _cfg
    kind: Literal["materializes"] = "materializes"


# ---- discriminated union --------------------------------------------

# Union of the eight tier-0 constraint models, discriminated on ``kind``.
Constraint = Annotated[
    Union[
        NodeExists,
        NodeAbsent,
        EdgeExists,
        ModuleCount,
        ModuleSizeMax,
        ModuleResponsibility,
        AcyclicImports,
        Materializes,
    ],
    Field(discriminator="kind"),
]


# Set of kinds considered "structural" for the reward engine's per-constraint
# +1 magnitude. The "behavioral" family is reserved for property-test results
# (BehavioralTestPasses, TODO) which earn the higher +3 magnitude. The
# ``materializes`` constraint is structural for scoring purposes; the more
# severe "Materialization fails: -8" penalty in PROPOSAL.md §5.2 is an
# independent gate driven by the materializer raising or returning parse
# errors, not by this constraint kind.
# Every kind in the ``Constraint`` union above is currently listed here, so
# nothing falls into the behavioral family yet.
STRUCTURAL_KINDS = {
    "node_exists",
    "node_absent",
    "edge_exists",
    "module_count",
    "module_size_max",
    "module_responsibility",
    "acyclic_imports",
    "materializes",
}


__all__ = [
    "AcyclicImports",
    "Constraint",
    "EdgeExists",
    "Materializes",
    "ModuleCount",
    "ModuleResponsibility",
    "ModuleSizeMax",
    "NodeAbsent",
    "NodeExists",
    "STRUCTURAL_KINDS",
]
See :mod:`graphforge.graph.schema`.""" + +from graphforge.graph.schema import ( + ArgMapping, + Edge, + ErrorPolicy, + Graph, + Module, + Node, + Purity, + ResponsibilityTag, +) + +__all__ = [ + "ArgMapping", + "Edge", + "ErrorPolicy", + "Graph", + "Module", + "Node", + "Purity", + "ResponsibilityTag", +] diff --git a/graphforge/graph/schema.py b/graphforge/graph/schema.py new file mode 100644 index 0000000000000000000000000000000000000000..c2bc176073c477a5857979f96573a7b2a9165482 --- /dev/null +++ b/graphforge/graph/schema.py @@ -0,0 +1,308 @@ +"""Canonical graph schema. + +The graph is the single source of truth for an in-progress program. Every +materialization is a deterministic function of (graph, template library). + +Wire format mirrors the JSON shape documented in PROPOSAL.md §3.1, exactly: + + { + "modules": [{"name": ..., "responsibility": ...}, ...], + "nodes": [{"name": ..., "module": ..., "signature": ..., + "body_template": ..., "body_template_args": {...}, + "purity": ..., "error_policy": ..., "decl_order": ...}, ...], + "edges": [{"caller": "<module>.<name>", + "callee": "<module>.<name>", + "arg_mapping": [{"caller_arg": ..., "callee_param": ...}, ...]}, ...] + } + +This module enforces shape and well-formedness only. Higher-order invariants +(unique names, edge endpoints exist, no cycles, type-flow compatibility) are +enforced by the action dispatcher and the type engine, not the schema, so +that callers can build partial / invalid graphs and inspect why they fail. +""" + +from __future__ import annotations + +import hashlib +import json +from typing import Literal, Optional + +from pydantic import BaseModel, ConfigDict, Field, field_validator + +# ---------------------------------------------------------------------- +# Enumerated tags +# ---------------------------------------------------------------------- + +# Responsibility tags constrain which kinds of nodes a module is allowed to +# host. 
# Canonical responsibility vocabulary. New tags are added intentionally
# because tasks encode constraints against this vocabulary.
ResponsibilityTag = Literal[
    "io",
    "validation",
    "transform",
    "orchestration",
    "storage",
    "formatting",
    "lookup",
    "policy",
    "logging",
    "computation",
]

Purity = Literal["pure", "impure"]

# How a function handles errors in its body. "guard" means it includes a
# guard / try-except. "propagate" means it deliberately lets errors flow up.
# "none" is the default - no claim either way.
ErrorPolicy = Literal["guard", "propagate", "none"]


# ----------------------------------------------------------------------
# Atomic records
# ----------------------------------------------------------------------


class Module(BaseModel):
    """A declared module - one Python file at materialization time."""

    model_config = ConfigDict(extra="forbid", frozen=False)

    name: str = Field(..., min_length=1)
    responsibility: ResponsibilityTag

    @field_validator("name")
    @classmethod
    def _name_is_identifier(cls, v: str) -> str:
        """Require an identifier that does not start with an underscore."""
        # NOTE(review): ``isidentifier`` also accepts Python keywords such as
        # "class"; confirm whether the materializer guards against those.
        if not v.isidentifier():
            raise ValueError(f"module name {v!r} is not a Python identifier")
        if v.startswith("_"):
            raise ValueError(f"module name {v!r} must not start with an underscore")
        return v


class Node(BaseModel):
    """A declared function. ``body_template`` may be unset until attach_body."""

    model_config = ConfigDict(extra="forbid", frozen=False)

    name: str = Field(..., min_length=1)
    module: str = Field(..., min_length=1)
    signature: str = Field(..., min_length=2)  # e.g., "(x: int) -> bool"
    body_template: Optional[str] = None
    body_template_args: dict[str, object] = Field(default_factory=dict)
    purity: Purity = "impure"
    error_policy: ErrorPolicy = "none"
    decl_order: int = 0

    @field_validator("name")
    @classmethod
    def _name_is_identifier(cls, v: str) -> str:
        """Require the node name to be a Python identifier."""
        if not v.isidentifier():
            raise ValueError(f"node name {v!r} is not a Python identifier")
        return v

    @field_validator("signature")
    @classmethod
    def _signature_shape(cls, v: str) -> str:
        """Surface check only: a leading ``(`` and a ``->`` somewhere."""
        # Cheap surface check; the type engine does the real parse.
        if not v.lstrip().startswith("("):
            raise ValueError(f"signature must start with '(': got {v!r}")
        if "->" not in v:
            raise ValueError(f"signature must include '->' return arrow: got {v!r}")
        return v

    # Convenience -----------------------------------------------------

    @property
    def qualified_name(self) -> str:
        """``<module>.<name>`` - the canonical address used on edges."""
        return f"{self.module}.{self.name}"


class ArgMapping(BaseModel):
    """How an edge wires a caller's argument to a callee's parameter."""

    model_config = ConfigDict(extra="forbid", frozen=False)

    caller_arg: str = Field(..., min_length=1)
    callee_param: str = Field(..., min_length=1)


class Edge(BaseModel):
    """A CALLS edge. Endpoints are qualified node names ``<module>.<name>``."""

    model_config = ConfigDict(extra="forbid", frozen=False)

    caller: str = Field(..., min_length=3)
    callee: str = Field(..., min_length=3)
    arg_mapping: list[ArgMapping] = Field(default_factory=list)

    @field_validator("caller", "callee")
    @classmethod
    def _qualified(cls, v: str) -> str:
        """Require exactly one dot separating two identifiers."""
        if v.count(".") != 1:
            # The message spells out the expected shape so the sender can
            # correct the endpoint.
            raise ValueError(
                f"edge endpoint {v!r} is not qualified (expected '<module>.<name>')"
            )
        mod, name = v.split(".")
        if not mod.isidentifier() or not name.isidentifier():
            raise ValueError(f"edge endpoint {v!r} has non-identifier parts")
        return v


# ----------------------------------------------------------------------
# Graph
# ----------------------------------------------------------------------


class Graph(BaseModel):
    """Canonical graph state. Mutable; deep-copied via ``snapshot``.

    All lookups are linear scans over the ``modules`` / ``nodes`` / ``edges``
    lists. This class defines no ``restore`` counterpart to ``snapshot``;
    rolling back is left to whoever holds the snapshot.
    """

    model_config = ConfigDict(extra="forbid", frozen=False)

    modules: list[Module] = Field(default_factory=list)
    nodes: list[Node] = Field(default_factory=list)
    edges: list[Edge] = Field(default_factory=list)

    # ----- lookup ----------------------------------------------------

    def find_module(self, name: str) -> Optional[Module]:
        """Return the first module called ``name``, or ``None``."""
        for m in self.modules:
            if m.name == name:
                return m
        return None

    def find_node(self, name: str, module: str) -> Optional[Node]:
        """Return the first node matching (name, module), or ``None``."""
        for n in self.nodes:
            if n.name == name and n.module == module:
                return n
        return None

    def find_node_qualified(self, qualified: str) -> Optional[Node]:
        """Look a node up by ``<module>.<name>``; ``None`` if malformed or absent."""
        if qualified.count(".") != 1:
            return None
        mod, nm = qualified.split(".")
        return self.find_node(nm, mod)

    def find_edge(self, caller: str, callee: str) -> Optional[Edge]:
        """Return the first edge with these exact endpoints, or ``None``."""
        for e in self.edges:
            if e.caller == caller and e.callee == callee:
                return e
        return None

    def nodes_in_module(self, module: str) -> list[Node]:
        """Nodes whose ``module`` equals ``module``, in list order."""
        return [n for n in self.nodes if n.module == module]

    def callers_of(self, qualified: str) -> list[str]:
        """Caller endpoints of every edge whose callee is ``qualified``."""
        return [e.caller for e in self.edges if e.callee == qualified]

    def callees_of(self, qualified: str) -> list[str]:
        """Callee endpoints of every edge whose caller is ``qualified``."""
        return [e.callee for e in self.edges if e.caller == qualified]

    def fan_in(self, qualified: str) -> int:
        """Number of edges that point at ``qualified``."""
        return len(self.callers_of(qualified))

    def fan_out(self, qualified: str) -> int:
        """Number of edges that leave ``qualified``."""
        return len(self.callees_of(qualified))

    # ----- structural derivations ------------------------------------

    def import_edges(self) -> set[tuple[str, str]]:
        """Set of (caller_module, callee_module) pairs from cross-module edges."""
        out: set[tuple[str, str]] = set()
        for e in self.edges:
            # The module is the part of the endpoint before the first dot.
            cm = e.caller.split(".")[0]
            tm = e.callee.split(".")[0]
            if cm != tm:
                out.add((cm, tm))
        return out

    def has_module_cycle(self) -> bool:
        """True iff the cross-module import graph contains a directed cycle."""
        adj: dict[str, set[str]] = {m.name: set() for m in self.modules}
        for src, dst in self.import_edges():
            # Edges may mention modules that were never declared; give them
            # adjacency entries so the traversal can colour them.
            adj.setdefault(src, set()).add(dst)
            adj.setdefault(dst, set())
        WHITE, GRAY, BLACK = 0, 1, 2
        color: dict[str, int] = {k: WHITE for k in adj}

        def visit(u: str) -> bool:
            # GRAY marks modules on the current DFS path; reaching one again
            # closes a cycle.
            color[u] = GRAY
            for v in adj.get(u, ()):
                if color[v] == GRAY:
                    return True
                if color[v] == WHITE and visit(v):
                    return True
            color[u] = BLACK
            return False

        return any(color[u] == WHITE and visit(u) for u in adj)

    def call_graph_depth(self) -> int:
        """Longest path length (in edges) in the function call DAG.

        If the call graph is cyclic, returns the special value -1 (callers
        should treat this as an invariant violation).
        """
        adj: dict[str, list[str]] = {n.qualified_name: [] for n in self.nodes}
        for e in self.edges:
            adj.setdefault(e.caller, []).append(e.callee)
            adj.setdefault(e.callee, [])
        memo: dict[str, int] = {}
        ON_STACK = -2  # sentinel: node is on the current DFS path

        def dfs(u: str) -> int:
            if u in memo:
                if memo[u] == ON_STACK:
                    return -1
                return memo[u]
            memo[u] = ON_STACK
            best = 0
            for v in adj.get(u, ()):
                d = dfs(v)
                if d == -1:
                    # The sentinel is left in place, so later visits to this
                    # node also report the cycle.
                    return -1
                best = max(best, d + 1)
            memo[u] = best
            return best

        results = [dfs(u) for u in adj]
        if any(r == -1 for r in results):
            return -1
        return max(results, default=0)

    # ----- copying / hashing -----------------------------------------

    def snapshot(self) -> "Graph":
        """Deep copy. Used by the dispatcher for atomic action rollback."""
        return self.model_copy(deep=True)

    def structural_hash(self) -> str:
        """Stable SHA-256 over a canonical JSON projection.

        Insensitive to the list order of modules, nodes and edges (each list
        is sorted before hashing), but sensitive to ``decl_order`` because
        that affects materialized output. The order of entries inside an
        edge's ``arg_mapping`` is kept as-is.
        """
        canon: dict[str, object] = {
            "modules": sorted(
                [m.model_dump() for m in self.modules],
                key=lambda d: d["name"],
            ),
            "nodes": sorted(
                [n.model_dump() for n in self.nodes],
                key=lambda d: (d["module"], d["name"]),
            ),
            "edges": sorted(
                [e.model_dump() for e in self.edges],
                key=lambda d: (d["caller"], d["callee"]),
            ),
        }
        blob = json.dumps(canon, sort_keys=True, default=str).encode("utf-8")
        return hashlib.sha256(blob).hexdigest()

    # ----- factories -------------------------------------------------

    @classmethod
    def empty(cls) -> "Graph":
        """Return a graph with no modules, nodes or edges."""
        return cls()
@dataclass
class KGNode:
    """A single vertex of the repository knowledge graph."""

    node_id: str  # unique key, e.g. "function:validators.py:validate_title"
    node_type: str  # module | class | function | method | package | repo
    name: str  # short identifier
    file_path: str  # relative path from repo root (empty for repo/package)
    line_start: int = 0
    line_end: int = 0
    source: str = ""  # full source text of this node (incl. def line)
    docstring: str = ""
    metadata: dict = field(default_factory=dict)

    def brief(self) -> str:
        """One-line summary: padded upper-case type, id, optional location."""
        location = ""
        if self.file_path:
            location = f" [{self.file_path}:{self.line_start}]"
        type_tag = self.node_type.upper()
        return f"[{type_tag:<8}] {self.node_id}{location}"


@dataclass
class KGEdge:
    """A directed, typed link between two node ids."""

    edge_type: str  # contains | calls | imports | inherits
    source_id: str
    target_id: str


# ---- knowledge graph --------------------------------------------------

class KnowledgeGraph:
    """Property graph for a repository.

    Nodes are kept in a dict keyed by ``node_id``; edges are kept in a list in
    insertion order. Supports rich queries used by the agent and reward
    checker.
    """

    def __init__(self, repo_path: str) -> None:
        self.repo_path = repo_path
        self._nodes: dict[str, KGNode] = {}
        self._edges: list[KGEdge] = []

    # ---- mutation ------------------------------------------------------

    def add_node(self, node: KGNode) -> None:
        """Store ``node`` under its id, replacing any previous entry."""
        self._nodes[node.node_id] = node

    def add_edge(self, edge: KGEdge) -> None:
        """Append ``edge``; duplicates are not filtered."""
        self._edges.append(edge)

    def update_node_source(self, node_id: str, new_source: str) -> None:
        """Replace a node's source and recompute ``line_end`` from its length.

        Raises ``KeyError`` when ``node_id`` is unknown.
        """
        target = self._nodes[node_id]
        target.source = new_source
        line_count = len(new_source.splitlines())
        target.line_end = target.line_start + line_count - 1

    def insert_node(
        self,
        parent_id: str,
        new_node: KGNode,
    ) -> None:
        """Register ``new_node`` and link it to ``parent_id`` with a contains edge."""
        link = KGEdge(
            edge_type="contains",
            source_id=parent_id,
            target_id=new_node.node_id,
        )
        self._nodes[new_node.node_id] = new_node
        self._edges.append(link)

    def remove_node(self, node_id: str) -> None:
        """Drop the node (if present) and every edge that touches it."""
        self._nodes.pop(node_id, None)
        surviving: list[KGEdge] = []
        for edge in self._edges:
            if edge.source_id == node_id or edge.target_id == node_id:
                continue
            surviving.append(edge)
        self._edges = surviving
โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + + def get_node(self, node_id: str) -> KGNode | None: + return self._nodes.get(node_id) + + def all_nodes(self, node_type: str | None = None) -> list[KGNode]: + nodes = list(self._nodes.values()) + if node_type: + nodes = [n for n in nodes if n.node_type == node_type] + return nodes + + def children_of(self, node_id: str) -> list[KGNode]: + child_ids = {e.target_id for e in self._edges + if e.source_id == node_id and e.edge_type == "contains"} + return [self._nodes[cid] for cid in child_ids if cid in self._nodes] + + def parent_of(self, node_id: str) -> KGNode | None: + for e in self._edges: + if e.target_id == node_id and e.edge_type == "contains": + return self._nodes.get(e.source_id) + return None + + def callers_of(self, node_id: str) -> list[KGNode]: + caller_ids = {e.source_id for e in self._edges + if e.target_id == node_id and e.edge_type == "calls"} + return [self._nodes[cid] for cid in caller_ids if cid in self._nodes] + + def callees_of(self, node_id: str) -> list[KGNode]: + callee_ids = {e.target_id for e in self._edges + if e.source_id == node_id and e.edge_type == "calls"} + return [self._nodes[cid] for cid in callee_ids if cid in self._nodes] + + def imports_of(self, module_id: str) -> list[KGNode]: + imp_ids = {e.target_id for e in self._edges + if e.source_id == module_id and e.edge_type == "imports"} + return [self._nodes[i] for i in imp_ids if i in self._nodes] + + def search(self, keywords: str, node_type: str | None = None) -> list[KGNode]: + """Fuzzy keyword search over node names, docstrings, and source.""" + kws = keywords.lower().split() + results: list[KGNode] = [] + for node in self._nodes.values(): + if node_type and node.node_type != node_type: + continue + haystack = f"{node.name} {node.docstring} {node.source}".lower() + if all(kw in haystack for 
kw in kws): + results.append(node) + return results + + def subgraph(self, root_id: str, depth: int = 2) -> list[KGNode]: + """BFS from root_id up to depth hops; returns all encountered nodes.""" + visited: set[str] = set() + frontier = {root_id} + for _ in range(depth): + next_frontier: set[str] = set() + for nid in frontier: + if nid in visited: + continue + visited.add(nid) + for e in self._edges: + if e.source_id == nid and e.target_id not in visited: + next_frontier.add(e.target_id) + frontier = next_frontier + visited.update(frontier) + return [self._nodes[nid] for nid in visited if nid in self._nodes] + + # โ”€โ”€ text representations โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + + def overview(self, max_chars: int = 3000) -> str: + """Compact multi-line overview of the repo graph, capped to avoid LLM context overflow.""" + lines: list[str] = [f"## Repository: {self.repo_path}", ""] + modules = self.all_nodes("module") + all_fns = self.all_nodes("function") + all_cls = self.all_nodes("class") + lines.append(f" {len(modules)} modules ยท {len(all_fns)} functions ยท {len(all_cls)} classes") + lines.append("") + + for mod in sorted(modules, key=lambda n: n.file_path): + children = self.children_of(mod.node_id) + funcs = [c for c in children if c.node_type in ("function", "method")] + classes = [c for c in children if c.node_type == "class"] + summary = [] + if classes: + summary.append(f"{len(classes)} class{'es' if len(classes)>1 else ''}") + if funcs: + summary.append(f"{len(funcs)} fn{'s' if len(funcs)>1 else ''}") + lines.append(f" [{mod.file_path}] ({', '.join(summary) or 'empty'})") + for cls in sorted(classes, key=lambda n: n.name): + methods = [c for c in self.children_of(cls.node_id) if c.node_type == "method"] + mnames = ", ".join(m.name for m in sorted(methods, key=lambda n: n.line_start)) + lines.append(f" class {cls.name} โ†’ {mnames or '(no 
methods)'}") + lines.append(f" node_id: {cls.node_id}") + for fn in sorted(funcs, key=lambda n: n.line_start): + lines.append(f" def {fn.name}{fn.metadata.get('signature', '')}") + lines.append(f" node_id: {fn.node_id}") + + # Stop expanding if we are already near the character cap + current = "\n".join(lines) + if len(current) > max_chars: + remaining = len(modules) - (modules.index(mod) + 1) + if remaining: + lines.append(f"\n ... [{remaining} more modules not shown โ€” use query() to explore]") + break + + return "\n".join(lines) + + def node_detail(self, node_id: str) -> str: + """Full inspection view of a single node.""" + node = self._nodes.get(node_id) + if node is None: + return f"[ERROR] node_id {node_id!r} not found in graph." + lines = [ + f"## Node: {node.node_id}", + f"type : {node.node_type}", + f"file : {node.file_path} (lines {node.line_start}โ€“{node.line_end})", + ] + if node.docstring: + lines.append(f"docstring: {node.docstring[:120]}") + callers = self.callers_of(node_id) + callees = self.callees_of(node_id) + if callers: + lines.append("called by: " + ", ".join(n.name for n in callers)) + if callees: + lines.append("calls : " + ", ".join(n.name for n in callees)) + children = self.children_of(node_id) + if children: + lines.append("contains : " + ", ".join(c.name for c in children)) + lines += ["", "### Source", "```python", node.source or "(no source)", "```"] + return "\n".join(lines) + + def snapshot(self) -> "KnowledgeGraph": + """Deep copy โ€” used to preserve state before mutations.""" + import copy + return copy.deepcopy(self) diff --git a/graphforge/materializer/__init__.py b/graphforge/materializer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0d568b7de75e53c047058946a163cf90d6b40d80 --- /dev/null +++ b/graphforge/materializer/__init__.py @@ -0,0 +1,20 @@ +"""Graph -> Python source projection. + +Responsibilities (PROPOSAL.md ยง3.3): + + * Emit one ``.py`` per declared module. 
+ * Emit functions in :attr:`Node.decl_order` order. + * Compute ``from MODULE import NAME`` lines from cross-module edges, + deduplicated and sorted. + * Expand body templates with the node's ``body_template_args`` to produce + a runnable function body. + +The materializer is total over well-formed graphs: every dispatcher-accepted +graph must produce parseable source. Round-trip correctness (the produced +source re-parses to the same graph) is enforced by tests in +:mod:`graphforge.parser` (TODO). +""" + +from graphforge.materializer.materialize import materialize + +__all__ = ["materialize"] diff --git a/graphforge/materializer/codegen.py b/graphforge/materializer/codegen.py new file mode 100644 index 0000000000000000000000000000000000000000..b83e9b64b61cfa6751fbe8487cd8b2b479c25102 --- /dev/null +++ b/graphforge/materializer/codegen.py @@ -0,0 +1,169 @@ +"""Per-template body codegen. + +Each public ``render_