Spaces:
Sleeping
Sleeping
| """Auto-generate training tasks from any Python repository. | |
| Pipeline | |
| -------- | |
| 1. Parse the repo with AST β KnowledgeGraph | |
| 2. Find public functions that have doctest examples (>>> in docstring) | |
| 3. Extract those examples as runnable assertions | |
| 4. Replace the function body with `raise NotImplementedError` β the agent | |
| must re-implement it from the docstring alone | |
| 5. Return RepoTask objects ready for GRPO training β no hand-writing needed | |
| Usage | |
| ----- | |
| from graphforge.task_generator import generate_tasks | |
| tasks = generate_tasks("/tmp/humanize/src/humanize", n_tasks=6) | |
| for t in tasks: | |
| print(t.task_id, "β", t.description[:60]) | |
| """ | |
| from __future__ import annotations | |
| import ast | |
| import doctest | |
| import textwrap | |
| from dataclasses import dataclass, field | |
| from pathlib import Path | |
| from typing import Any | |
| from graphforge.knowledge_graph import KGNode, KnowledgeGraph | |
| from graphforge.repo_parser import parse_repo | |
# ── Task dataclass (mirrors env.tasks.RepoTask but lives here to avoid circular import) ──
@dataclass
class AutoTask:
    """One auto-generated "re-implement this function" training task.

    Must be a dataclass: instances are built with keyword arguments, and
    ``hints`` relies on ``field(default_factory=list)`` so every task gets
    its own list.
    """

    task_id: str
    repo_name: str
    repo_path: str  # absolute path to the repo source directory
    description: str  # prompt shown to the agent (function name + docstring)
    test_code: str  # uses short import: from <repo_name>.<module> import <func>
    stubbed_node_id: str  # the node whose body was replaced
    original_source: str  # saved so env can restore on reset
    max_turns: int = 12
    difficulty: int = 0  # generate_tasks assigns a value in 0..2
    hints: list[str] = field(default_factory=list)
| # ββ Doctest extraction ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _extract_all_examples(docstring: str) -> list[tuple[str, str]]: | |
| """Return ALL doctest lines as (source, want) β want is '' for setup lines.""" | |
| if not docstring: | |
| return [] | |
| parser = doctest.DocTestParser() | |
| try: | |
| examples = parser.get_examples(docstring, name="<doc>") | |
| return [(ex.source.strip(), ex.want.strip()) for ex in examples] | |
| except Exception: | |
| return [] | |
| def _to_assertion(expr: str, expected: str) -> str | None: | |
| """Convert one doctest example to a Python assertion. | |
| - True/False expected β assert (expr) is True/False | |
| - Traceback expected β skip | |
| - Non-literal β skip | |
| """ | |
| if not expected or expected.startswith("Traceback"): | |
| return None | |
| if expected in ("True", "False"): | |
| return f"assert ({expr}) is {expected}, f'got {{repr({expr})}}'" | |
| try: | |
| ast.literal_eval(expected) | |
| except (ValueError, SyntaxError): | |
| return None | |
| return f"assert {expr} == {expected}, f'got {{repr({expr})}}'" | |
def _build_test_code(func_name: str, module_stem: str, repo_name: str,
                     all_examples: list[tuple[str, str]],
                     min_assertions: int = 2) -> str | None:
    """Assemble the test script for one function from its doctest examples.

    The script is, in order: an import of the function, every setup example
    (one with no expected output, e.g. an import or assignment) in docstring
    order, then every assertion from ``_to_assertion`` whose example source
    actually refers to ``func_name``.

    Args:
        func_name: Name of the function under test.
        module_stem: Module file stem the function is imported from.
        repo_name: Top-level package name used in the import.
        all_examples: ``(source, want)`` pairs, ``want`` empty for setup lines.
        min_assertions: Minimum number of kept assertions; with fewer, the
            function is rejected as a task. Defaults to 2.

    Returns:
        The newline-joined test code, or ``None`` if too few assertions remain.
    """
    import_line = f"from {repo_name}.{module_stem} import {func_name}"
    setup_lines: list[str] = []
    assertion_lines: list[str] = []
    for expr, expected in all_examples:
        if not expected:
            setup_lines.append(expr)
            continue
        assertion = _to_assertion(expr, expected)
        # Match on the parsed source rather than a substring of the assertion:
        # a substring test accepts `naturalsize(...)` for a function named
        # `size`, or a hit inside the expected value.
        if assertion and _references_name(expr, func_name):
            assertion_lines.append(assertion)
    if len(assertion_lines) < min_assertions:
        return None
    return "\n".join([import_line, *setup_lines, *assertion_lines])


def _references_name(source: str, name: str) -> bool:
    """Return True if *source* uses *name* as a bare name or as ``obj.name``."""
    try:
        tree = ast.parse(source)
    except (SyntaxError, ValueError):
        return False
    return any(
        (isinstance(node, ast.Name) and node.id == name)
        or (isinstance(node, ast.Attribute) and node.attr == name)
        for node in ast.walk(tree)
    )
| # ββ Function stubbing βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _stub_function(source: str) -> str: | |
| """Replace a function body with `raise NotImplementedError`, keeping signature + docstring.""" | |
| dedented = textwrap.dedent(source) | |
| try: | |
| tree = ast.parse(dedented) | |
| except SyntaxError: | |
| return source | |
| lines = dedented.splitlines() | |
| for node in ast.walk(tree): | |
| if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): | |
| continue | |
| body = node.body | |
| indent = " " * (node.col_offset // 4 + 1) | |
| # Keep signature lines (everything up to and including the colon) | |
| sig_end = body[0].lineno - 1 # 0-indexed line where body starts | |
| # Keep docstring if present | |
| if body and isinstance(body[0], ast.Expr) and isinstance(body[0].value, ast.Constant): | |
| keep_until = body[0].end_lineno # inclusive, 1-indexed | |
| else: | |
| keep_until = sig_end | |
| kept = "\n".join(lines[:keep_until]) | |
| stub = kept.rstrip() + f"\n{indent}raise NotImplementedError\n" | |
| return stub | |
| return source | |
| # ββ Candidate selection βββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _score_candidate(node: KGNode, examples: list) -> int: | |
| """Higher = better training signal. Prefer more examples and longer docstrings.""" | |
| return len(examples) * 3 + min(len(node.docstring or ""), 200) // 20 | |
def _find_candidates(kg: KnowledgeGraph, repo_name: str) -> list[tuple[KGNode, str, int]]:
    """Collect every function that can become a task, best first.

    A function qualifies when it is public (no leading underscore), has both
    a docstring and source, has a known file path, and its doctest examples
    produce usable test code. Each result is ``(node, test_code, score)``;
    the list is ordered by descending score, ties keeping graph order.
    """
    viable: list[tuple[KGNode, str, int]] = []
    for fn in kg.all_nodes("function"):
        if fn.name.startswith("_") or not (fn.docstring and fn.source):
            continue
        stem = Path(fn.file_path).stem if fn.file_path else None
        if not stem:
            continue
        doc_examples = _extract_all_examples(fn.docstring)
        if not doc_examples:
            continue
        code = _build_test_code(fn.name, stem, repo_name, doc_examples)
        if code:
            viable.append((fn, code, _score_candidate(fn, doc_examples)))
    return sorted(viable, key=lambda entry: entry[2], reverse=True)
# ── Main entry point ────────────────────────────────────────────────────────
def generate_tasks(
    repo_source_dir: str,
    n_tasks: int = 4,
    max_turns: int = 12,
) -> tuple[KnowledgeGraph, list[AutoTask]]:
    """Parse a Python repo directory and auto-generate training tasks.

    Args:
        repo_source_dir: Path to the Python package source directory,
            e.g. '/tmp/humanize/src/humanize'. Its final path component is
            used as the package name in the generated test imports.
        n_tasks: How many tasks to generate (the highest-scoring candidates
            are picked). Must be >= 0; fewer tasks are returned when the repo
            has fewer candidates.
        max_turns: Max turns per episode.

    Returns:
        (kg, tasks) — the Knowledge Graph and the list of AutoTask objects.

    Raises:
        ValueError: if ``n_tasks`` is negative, or if no public function has
            enough usable doctest examples.
    """
    if n_tasks < 0:
        # A negative slice bound would silently drop candidates from the end.
        raise ValueError(f"n_tasks must be non-negative, got {n_tasks}")

    repo_source_dir = str(Path(repo_source_dir).resolve())
    repo_name = Path(repo_source_dir).name
    kg = parse_repo(repo_source_dir)

    candidates = _find_candidates(kg, repo_name)
    if not candidates:
        raise ValueError(
            f"No suitable candidates found in {repo_source_dir}. "
            "Make sure functions have doctest examples (>>> in docstring)."
        )

    tasks: list[AutoTask] = []
    used_ids: set[str] = set()
    for node, test_code, score in candidates[:n_tasks]:
        task_id = f"auto.{repo_name}.{node.name}"
        if task_id in used_ids:
            # Same public name in two modules: qualify with the module stem
            # so task ids stay unique.
            task_id = f"auto.{repo_name}.{Path(node.file_path).stem}.{node.name}"
        used_ids.add(task_id)

        # Built directly rather than via textwrap.dedent on an f-string: the
        # interpolated docstring's own indentation would defeat the dedent.
        doc = node.docstring.strip() if node.docstring else "No docstring available."
        description = f"Implement the function `{node.name}` in `{node.file_path}`.\n\n{doc}"

        tasks.append(AutoTask(
            task_id=task_id,
            repo_name=repo_name,
            repo_path=repo_source_dir,
            description=description,
            test_code=test_code,
            stubbed_node_id=node.node_id,
            original_source=node.source,
            max_turns=max_turns,
            difficulty=min(2, max(0, score // 8)),  # bucket the score into 0..2
            hints=[
                f"Look at {node.file_path} to understand the module style.",
                f"The function signature is: {node.name}{node.metadata.get('signature', '(...)')}",
            ],
        ))
    return kg, tasks