File size: 8,297 Bytes
7952f32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
"""Auto-generate training tasks from any Python repository.

Pipeline
--------
1. Parse the repo with AST → KnowledgeGraph
2. Find public functions that have doctest examples (>>> in docstring)
3. Extract those examples as runnable assertions
4. Replace the function body with `raise NotImplementedError` — the agent
   must re-implement it from the docstring alone
5. Return AutoTask objects (mirroring env.tasks.RepoTask) ready for GRPO
   training — no hand-writing needed

Usage
-----
    from graphforge.task_generator import generate_tasks
    kg, tasks = generate_tasks("/tmp/humanize/src/humanize", n_tasks=6)
    for t in tasks:
        print(t.task_id, "→", t.description[:60])
"""

from __future__ import annotations

import ast
import doctest
import inspect
import textwrap
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

from graphforge.knowledge_graph import KGNode, KnowledgeGraph
from graphforge.repo_parser import parse_repo


# ── Task dataclass (mirrors env.tasks.RepoTask but lives here to avoid circular import) ──

@dataclass
class AutoTask:
    """One auto-generated re-implementation task.

    Instances are built by ``generate_tasks``: the agent is given
    ``description`` and is judged by running ``test_code`` against its
    re-implementation of the function identified by ``stubbed_node_id``.
    """

    task_id: str            # "auto.<repo_name>.<function name>" when built by generate_tasks
    repo_name: str          # package directory name; the import root used in test_code
    repo_path: str          # absolute path to the repo source directory
    description: str        # prompt text: what to implement plus the function's docstring
    test_code: str          # uses short import: from <repo_name>.<module> import <func>
    stubbed_node_id: str    # the node whose body was replaced
    original_source: str    # saved so env can restore on reset
    max_turns: int = 12     # per-episode turn budget
    difficulty: int = 0     # 0..2 when set by generate_tasks (derived from the candidate score)
    hints: list[str] = field(default_factory=list)  # free-text hints (module to read, signature)


# ── Doctest extraction ────────────────────────────────────────────────────────

def _extract_all_examples(docstring: str) -> list[tuple[str, str]]:
    """Return ALL doctest lines as (source, want) β€” want is '' for setup lines."""
    if not docstring:
        return []
    parser = doctest.DocTestParser()
    try:
        examples = parser.get_examples(docstring, name="<doc>")
        return [(ex.source.strip(), ex.want.strip()) for ex in examples]
    except Exception:
        return []


def _to_assertion(expr: str, expected: str) -> str | None:
    """Convert one doctest example to a Python assertion.

    - True/False expected β†’ assert (expr) is True/False
    - Traceback expected  β†’ skip
    - Non-literal         β†’ skip
    """
    if not expected or expected.startswith("Traceback"):
        return None
    if expected in ("True", "False"):
        return f"assert ({expr}) is {expected}, f'got {{repr({expr})}}'"
    try:
        ast.literal_eval(expected)
    except (ValueError, SyntaxError):
        return None
    return f"assert {expr} == {expected}, f'got {{repr({expr})}}'"


def _references_name(expr: str, name: str) -> bool:
    """Return True if *expr* refers to *name* as a bare name or an attribute.

    Matches ``name(...)`` and ``pkg.name(...)``, but not identifiers that only
    contain *name* as a substring (``intword`` for ``word``). Source that does
    not parse counts as no reference.
    """
    try:
        tree = ast.parse(expr)
    except (SyntaxError, ValueError):
        return False
    return any(
        (isinstance(node, ast.Name) and node.id == name)
        or (isinstance(node, ast.Attribute) and node.attr == name)
        for node in ast.walk(tree)
    )


def _build_test_code(func_name: str, module_stem: str, repo_name: str,
                     all_examples: list[tuple[str, str]]) -> str | None:
    """Build complete test code: an import line, then setup lines and assertions.

    Setup lines (examples with no expected output) and assertions are kept in
    docstring order. Examples that rebind state between checks
    (``>>> x = 1`` ... ``>>> x = 5``) are therefore replayed the way doctest
    runs them.

    Args:
        func_name: Function under test; imported by name.
        module_stem: Module file stem inside the package.
        repo_name: Package name used as the import root.
        all_examples: ``(source, want)`` pairs; ``want`` is '' for setup lines.

    Returns:
        The test source, or ``None`` when fewer than two assertions refer to
        *func_name*.
    """
    body: list[str] = []
    n_assertions = 0
    for expr, expected in all_examples:
        if not expected:
            body.append(expr)
            continue
        assertion = _to_assertion(expr, expected)
        # Keep only assertions whose expression actually refers to our
        # function. Searching the generated assertion text for the name as a
        # substring also matched unrelated identifiers and the message text.
        if assertion and _references_name(expr, func_name):
            body.append(assertion)
            n_assertions += 1

    if n_assertions < 2:
        return None
    import_line = f"from {repo_name}.{module_stem} import {func_name}"
    return "\n".join([import_line, *body])


# ── Function stubbing ─────────────────────────────────────────────────────────

def _stub_function(source: str) -> str:
    """Replace a function body with `raise NotImplementedError`, keeping signature + docstring."""
    dedented = textwrap.dedent(source)
    try:
        tree = ast.parse(dedented)
    except SyntaxError:
        return source

    lines = dedented.splitlines()
    for node in ast.walk(tree):
        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            continue

        body = node.body
        indent = "    " * (node.col_offset // 4 + 1)

        # Keep signature lines (everything up to and including the colon)
        sig_end = body[0].lineno - 1  # 0-indexed line where body starts

        # Keep docstring if present
        if body and isinstance(body[0], ast.Expr) and isinstance(body[0].value, ast.Constant):
            keep_until = body[0].end_lineno  # inclusive, 1-indexed
        else:
            keep_until = sig_end

        kept = "\n".join(lines[:keep_until])
        stub = kept.rstrip() + f"\n{indent}raise NotImplementedError\n"
        return stub

    return source


# ── Candidate selection ───────────────────────────────────────────────────────

def _score_candidate(node: KGNode, examples: list) -> int:
    """Higher = better training signal. Prefer more examples and longer docstrings."""
    return len(examples) * 3 + min(len(node.docstring or ""), 200) // 20


def _find_candidates(kg: KnowledgeGraph, repo_name: str) -> list[tuple[KGNode, str, int]]:
    """Collect every function that can become a task, best first.

    A function qualifies when it is public (no leading underscore), has
    both a docstring and source, lives in a file with a usable module stem,
    and its doctests yield test code.

    Returns:
        ``(node, test_code, score)`` triples sorted by descending score;
        ties keep their ``kg.all_nodes`` order.
    """
    viable: list[tuple[KGNode, str, int]] = []
    for fn in kg.all_nodes("function"):
        is_private = fn.name.startswith("_")
        if is_private or not (fn.docstring and fn.source):
            continue

        stem = Path(fn.file_path).stem if fn.file_path else None
        if not stem:
            continue

        examples = _extract_all_examples(fn.docstring)
        if not examples:
            continue

        code = _build_test_code(fn.name, stem, repo_name, examples)
        if not code:
            continue

        viable.append((fn, code, _score_candidate(fn, examples)))

    return sorted(viable, key=lambda item: item[2], reverse=True)


# ── Main entry point ──────────────────────────────────────────────────────────

def generate_tasks(
    repo_source_dir: str,
    n_tasks: int = 4,
    max_turns: int = 12,
) -> tuple[KnowledgeGraph, list[AutoTask]]:
    """Parse a Python repo directory and auto-generate training tasks.

    Args:
        repo_source_dir: Path to the Python package source directory.
                         e.g. '/tmp/humanize/src/humanize'
        n_tasks: How many tasks to generate (picks highest-scoring candidates).
                 Values below 1 yield an empty task list.
        max_turns: Max turns per episode.

    Returns:
        (kg, tasks) — the Knowledge Graph and the list of AutoTask objects.

    Raises:
        ValueError: If no function in the repo has usable doctest examples.
    """
    repo_source_dir = str(Path(repo_source_dir).resolve())
    repo_name = Path(repo_source_dir).name
    kg = parse_repo(repo_source_dir)

    candidates = _find_candidates(kg, repo_name)
    if not candidates:
        raise ValueError(
            f"No suitable candidates found in {repo_source_dir}. "
            "Make sure functions have doctest examples (>>> in docstring)."
        )

    # Clamp so a negative n_tasks does not slice from the end of the list.
    selected = candidates[: max(n_tasks, 0)]
    tasks: list[AutoTask] = []

    for node, test_code, score in selected:
        # Assemble the prompt directly. Dedenting a template after a
        # multi-line docstring has been interpolated left the docstring's
        # first line indented by the template's 12 spaces.
        docstring = (
            inspect.cleandoc(node.docstring) if node.docstring else "No docstring available."
        )
        desc = f"Implement the function `{node.name}` in `{node.file_path}`.\n\n{docstring}"

        task = AutoTask(
            task_id=f"auto.{repo_name}.{node.name}",
            repo_name=repo_name,
            repo_path=repo_source_dir,
            description=desc,
            test_code=test_code,
            stubbed_node_id=node.node_id,
            original_source=node.source,
            max_turns=max_turns,
            difficulty=min(2, max(0, score // 8)),
            hints=[
                f"Look at {node.file_path} to understand the module style.",
                f"The function signature is: {node.name}{node.metadata.get('signature', '(...)')}",
            ],
        )
        tasks.append(task)

    return kg, tasks