"""Auto-generate training tasks from any Python repository.
Pipeline
--------
1. Parse the repo with AST → KnowledgeGraph
2. Find public functions that have doctest examples (>>> in docstring)
3. Extract those examples as runnable assertions
4. Replace the function body with `raise NotImplementedError` — the agent
   must re-implement it from the docstring alone
5. Return AutoTask objects ready for GRPO training — no hand-writing needed
Usage
-----
    from graphforge.task_generator import generate_tasks
    kg, tasks = generate_tasks("/tmp/humanize/src/humanize", n_tasks=6)
    for t in tasks:
        print(t.task_id, "→", t.description[:60])
"""
from __future__ import annotations
import ast
import doctest
import textwrap
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from graphforge.knowledge_graph import KGNode, KnowledgeGraph
from graphforge.repo_parser import parse_repo
# ── Task dataclass (mirrors env.tasks.RepoTask but lives here to avoid circular import) ──
@dataclass
class AutoTask:
    """One auto-generated "re-implement this function" task.

    The field order defines the generated ``__init__`` signature, so it is
    kept exactly as declared; only the last three fields have defaults.
    """

    # -- identity ---------------------------------------------------------
    task_id: str
    repo_name: str
    # Absolute path to the repo source directory.
    repo_path: str

    # -- what the agent sees and how it is checked -------------------------
    description: str
    # Test source; uses the short import form
    # ``from <repo_name>.<module> import <func>``.
    test_code: str

    # -- bookkeeping for stubbing / restoring ------------------------------
    # The node whose body was replaced.
    stubbed_node_id: str
    # Saved so the env can restore the function on reset.
    original_source: str

    # -- tunables ----------------------------------------------------------
    max_turns: int = 12
    difficulty: int = 0
    # default_factory gives every task its own list instead of a shared one.
    hints: list[str] = field(default_factory=list)
# ── Doctest extraction ────────────────────────────────────────────────────────
def _extract_all_examples(docstring: str) -> list[tuple[str, str]]:
"""Return ALL doctest lines as (source, want) β want is '' for setup lines."""
if not docstring:
return []
parser = doctest.DocTestParser()
try:
examples = parser.get_examples(docstring, name="<doc>")
return [(ex.source.strip(), ex.want.strip()) for ex in examples]
except Exception:
return []
def _to_assertion(expr: str, expected: str) -> str | None:
"""Convert one doctest example to a Python assertion.
- True/False expected β assert (expr) is True/False
- Traceback expected β skip
- Non-literal β skip
"""
if not expected or expected.startswith("Traceback"):
return None
if expected in ("True", "False"):
return f"assert ({expr}) is {expected}, f'got {{repr({expr})}}'"
try:
ast.literal_eval(expected)
except (ValueError, SyntaxError):
return None
return f"assert {expr} == {expected}, f'got {{repr({expr})}}'"
def _mentions_name(code: str, name: str) -> bool:
    """Return True if ``code`` uses ``name`` as an identifier or attribute."""
    try:
        tree = ast.parse(code.strip())
    except (SyntaxError, ValueError):
        return False
    for node in ast.walk(tree):
        if isinstance(node, ast.Name) and node.id == name:
            return True
        if isinstance(node, ast.Attribute) and node.attr == name:
            return True
    return False


def _build_test_code(func_name: str, module_stem: str, repo_name: str,
                     all_examples: list[tuple[str, str]]) -> str | None:
    """Build complete test code: the import, then the doctest lines in order.

    Examples without expected output are emitted verbatim as setup
    statements. Examples with expected output are converted by
    ``_to_assertion`` and kept only when the example's source actually
    references ``func_name``.

    Setup statements and assertions keep their docstring order, so a name
    rebound between two examples has the right value for each assertion.

    Args:
        func_name: Name of the function under test.
        module_stem: Module (file stem) the function is imported from.
        repo_name: Top-level package name used in the import line.
        all_examples: ``(source, want)`` pairs; ``want`` is ``''`` for setup.

    Returns:
        Newline-joined test source, or ``None`` when fewer than two
        assertions could be generated.
    """
    import_line = f"from {repo_name}.{module_stem} import {func_name}"
    body_lines: list[str] = []
    assertion_count = 0

    for expr, expected in all_examples:
        if not expected:
            body_lines.append(expr)
            continue
        assertion = _to_assertion(expr, expected)
        # Match the function as a real identifier in the example source; a
        # substring test would also accept e.g. `size` inside `naturalsize`
        # or text that only appears in the expected value.
        if assertion and _mentions_name(expr, func_name):
            body_lines.append(assertion)
            assertion_count += 1

    if assertion_count < 2:
        return None
    return "\n".join([import_line, *body_lines])
# ── Function stubbing ─────────────────────────────────────────────────────────
def _stub_function(source: str) -> str:
"""Replace a function body with `raise NotImplementedError`, keeping signature + docstring."""
dedented = textwrap.dedent(source)
try:
tree = ast.parse(dedented)
except SyntaxError:
return source
lines = dedented.splitlines()
for node in ast.walk(tree):
if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
continue
body = node.body
indent = " " * (node.col_offset // 4 + 1)
# Keep signature lines (everything up to and including the colon)
sig_end = body[0].lineno - 1 # 0-indexed line where body starts
# Keep docstring if present
if body and isinstance(body[0], ast.Expr) and isinstance(body[0].value, ast.Constant):
keep_until = body[0].end_lineno # inclusive, 1-indexed
else:
keep_until = sig_end
kept = "\n".join(lines[:keep_until])
stub = kept.rstrip() + f"\n{indent}raise NotImplementedError\n"
return stub
return source
# ── Candidate selection ───────────────────────────────────────────────────────
def _score_candidate(node: KGNode, examples: list) -> int:
"""Higher = better training signal. Prefer more examples and longer docstrings."""
return len(examples) * 3 + min(len(node.docstring or ""), 200) // 20
def _find_candidates(kg: KnowledgeGraph, repo_name: str) -> list[tuple[KGNode, str, int]]:
    """Collect ``(node, test_code, score)`` for every viable function node.

    A function node is viable when it is public (no leading underscore), has
    both a docstring and source, has a file path to derive a module name
    from, and its doctest examples yield usable test code. The result is
    ordered best score first.
    """
    scored: list[tuple[KGNode, str, int]] = []

    for fn_node in kg.all_nodes("function"):
        is_private = fn_node.name.startswith("_")
        if is_private or not (fn_node.docstring and fn_node.source):
            continue

        # The module name used in the generated import is the file's stem.
        stem = Path(fn_node.file_path).stem if fn_node.file_path else None
        if not stem:
            continue

        doc_examples = _extract_all_examples(fn_node.docstring)
        if not doc_examples:
            continue

        generated = _build_test_code(fn_node.name, stem, repo_name, doc_examples)
        if not generated:
            continue

        scored.append((fn_node, generated, _score_candidate(fn_node, doc_examples)))

    return sorted(scored, key=lambda entry: entry[2], reverse=True)
# ── Main entry point ──────────────────────────────────────────────────────────
def generate_tasks(
    repo_source_dir: str,
    n_tasks: int = 4,
    max_turns: int = 12,
) -> tuple[KnowledgeGraph, list[AutoTask]]:
    """Parse a Python repo directory and auto-generate training tasks.

    Args:
        repo_source_dir: Path to the Python package source directory,
            e.g. '/tmp/humanize/src/humanize'. It is resolved to an absolute
            path; its final component is used as the package name.
        n_tasks: How many tasks to generate (picks highest-scoring
            candidates). Zero or a negative value selects none.
        max_turns: Max turns per episode, copied onto every task.

    Returns:
        (kg, tasks) — the Knowledge Graph and the list of AutoTask objects.

    Raises:
        ValueError: If no function in the repo yields usable doctest-based
            test code.
    """
    repo_source_dir = str(Path(repo_source_dir).resolve())
    repo_name = Path(repo_source_dir).name

    kg = parse_repo(repo_source_dir)
    candidates = _find_candidates(kg, repo_name)
    if not candidates:
        raise ValueError(
            f"No suitable candidates found in {repo_source_dir}. "
            "Make sure functions have doctest examples (>>> in docstring)."
        )

    # Clamp so a negative count selects nothing rather than slicing
    # candidates off the end of the list.
    selected = candidates[: max(0, n_tasks)]

    tasks: list[AutoTask] = []
    for node, test_code, score in selected:
        doc_text = node.docstring.strip() if node.docstring else "No docstring available."
        # Assemble the text directly. Dedenting a template *after* the
        # docstring is interpolated makes the result depend on the
        # docstring's own indentation.
        desc = (
            f"Implement the function `{node.name}` in `{node.file_path}`.\n"
            f"{doc_text}"
        ).strip()

        tasks.append(
            AutoTask(
                task_id=f"auto.{repo_name}.{node.name}",
                repo_name=repo_name,
                repo_path=repo_source_dir,
                description=desc,
                test_code=test_code,
                stubbed_node_id=node.node_id,
                original_source=node.source,
                max_turns=max_turns,
                # Map the candidate score onto the 0..2 difficulty scale.
                difficulty=min(2, max(0, score // 8)),
                hints=[
                    f"Look at {node.file_path} to understand the module style.",
                    f"The function signature is: {node.name}{node.metadata.get('signature', '(...)')}",
                ],
            )
        )

    return kg, tasks
|