{ "cells": [ { "cell_type": "markdown", "id": "f70926e8", "metadata": {}, "source": [ "# DebugZero Training-First Workflow (Standalone)\n", "\n", "This notebook contains the complete `DebugZero` project inline. It includes the environment, executor, bug bank, seed bank, and the TRL GRPO training loop. You can run this sequentially in Colab or Kaggle." ] }, { "cell_type": "code", "execution_count": null, "id": "91a5f476", "metadata": {}, "outputs": [], "source": [ "!pip install -q trl transformers datasets openenv-core[core] pydantic thefuzz matplotlib bitsandbytes unsloth" ] }, { "cell_type": "code", "execution_count": null, "id": "5a22b902", "metadata": {}, "outputs": [], "source": [ "# 1. Models\n", "from openenv.core.env_server.types import Action, Observation, State\n", "from pydantic import Field\n", "\n", "class DebugzeroAction(Action):\n", " role: str = Field(..., description=\"Role taking action: 'proposer' or 'solver'\")\n", " code: str = Field(..., description=\"Code injected (by proposer) or fixed (by solver)\")\n", "\n", "class DebugzeroObservation(Observation):\n", " role_next: str = Field(default=\"proposer\", description=\"The role supposed to play next\")\n", " current_code: str = Field(default=\"\", description=\"The current state of the python code\")\n", " execution_result: str = Field(default=\"\", description=\"Result of evaluating tests\")\n", " tests_passed: bool = Field(default=False, description=\"Whether the tests passed\")\n", " syntax_error: bool = Field(default=False, description=\"Whether the code had a parse/syntax error\")\n", "\n", "class DebugzeroState(State):\n", " seed_id: str = Field(default=\"\", description=\"ID of the HumanEval function\")\n", " original_code: str = Field(default=\"\", description=\"Original clean seed code\")\n", " current_code: str = Field(default=\"\", description=\"Current code after turn\")\n", " role_turn: str = Field(default=\"proposer\", description=\"Current turn's role\")" ] }, { "cell_type": 
"code", "execution_count": null, "id": "d0564e89", "metadata": {}, "outputs": [], "source": [ "# 2. Seed Bank\n", "from dataclasses import dataclass\n", "\n", "@dataclass(frozen=True)\n", "class SeedSpec:\n", " seed_id: str\n", " entrypoint: str\n", " prompt: str\n", " canonical_solution: str\n", " test: str\n", "\n", " @property\n", " def original_code(self) -> str:\n", " return f\"{self.prompt}\\n{self.canonical_solution}\"\n", "\n", "SEED_BANK = (\n", " SeedSpec(\"HumanEval/0\", \"has_close_elements\", \"def has_close_elements(numbers: list[float], threshold: float) -> bool:\", \" for idx, elem in enumerate(numbers):\\n for idx2, elem2 in enumerate(numbers):\\n if idx != idx2:\\n distance = abs(elem - elem2)\\n if distance < threshold:\\n return True\\n return False\\n\", \"def check(candidate):\\n assert candidate([1.0, 2.0, 3.0], 0.5) is False\\n assert candidate([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3) is True\\n\\ncheck(has_close_elements)\\n\"),\n", " SeedSpec(\"DebugZero/1\", \"sum_to_n\", \"def sum_to_n(n: int) -> int:\", \" total = 0\\n for value in range(n + 1):\\n total += value\\n return total\\n\", \"def check(candidate):\\n assert candidate(0) == 0\\n assert candidate(1) == 1\\n assert candidate(5) == 15\\n assert candidate(10) == 55\\n\\ncheck(sum_to_n)\\n\"),\n", " SeedSpec(\"DebugZero/2\", \"middle_slice\", \"def middle_slice(values: list[int]) -> list[int]:\", \" if len(values) <= 2:\\n return []\\n return values[1:-1]\\n\", \"def check(candidate):\\n assert candidate([1]) == []\\n assert candidate([1, 2]) == []\\n assert candidate([1, 2, 3]) == [2]\\n assert candidate([1, 2, 3, 4, 5]) == [2, 3, 4]\\n\\ncheck(middle_slice)\\n\"),\n", " SeedSpec(\"DebugZero/3\", \"is_non_decreasing\", \"def is_non_decreasing(values: list[int]) -> bool:\", \" return all(values[idx] <= values[idx + 1] for idx in range(len(values) - 1))\\n\", \"def check(candidate):\\n assert candidate([]) is True\\n assert candidate([5]) is True\\n assert candidate([1, 2, 2, 3]) is 
True\\n assert candidate([3, 2]) is False\\n assert candidate([1, 3, 2, 4]) is False\\n\\ncheck(is_non_decreasing)\\n\"),\n", " SeedSpec(\"DebugZero/4\", \"count_nonempty\", \"def count_nonempty(strings: list[str]) -> int:\", \" total = 0\\n for text in strings:\\n if len(text) > 0:\\n total += 1\\n return total\\n\", \"def check(candidate):\\n assert candidate([]) == 0\\n assert candidate(['', '']) == 0\\n assert candidate(['a', '', 'bc', '']) == 2\\n assert candidate(['hi', 'there']) == 2\\n\\ncheck(count_nonempty)\\n\"),\n", " SeedSpec(\"DebugZero/5\", \"running_max\", \"def running_max(values: list[int]) -> int:\", \" best = values[0]\\n for idx in range(1, len(values)):\\n if values[idx] > best:\\n best = values[idx]\\n return best\\n\", \"def check(candidate):\\n assert candidate([3]) == 3\\n assert candidate([3, 1, 5, 2]) == 5\\n assert candidate([-1, -4, -2]) == -1\\n assert candidate([0, 0, 0]) == 0\\n\\ncheck(running_max)\\n\"),\n", ")\n", "SEED_BY_ID = {seed.seed_id: seed for seed in SEED_BANK}\n", "def get_seed_by_id(seed_id: str) -> SeedSpec: return SEED_BY_ID[seed_id]" ] }, { "cell_type": "code", "execution_count": null, "id": "0ed909d0", "metadata": {}, "outputs": [], "source": [ "# 3. 
Executor\n", "import os\n", "import re\n", "import subprocess\n", "import sys\n", "import tempfile\n", "import ast\n", "\n", "BLOCKED_IMPORTS = [\"os\", \"sys\", \"subprocess\", \"shutil\", \"pathlib\"]\n", "BLOCKED_BUILTINS = [\"__import__\", \"eval\", \"exec\", \"open\"]\n", "\n", "def is_safe(code: str) -> bool:\n", "    \"\"\"Return False if `code` imports a blocked module, names a blocked builtin, or fails to parse.\n", "\n", "    The textual pre-checks match whole words only, so identifiers such as\n", "    `opened`, `evaluate` or `executor` are not mistaken for `open`/`eval`/`exec`.\n", "    \"\"\"\n", "    for mod in BLOCKED_IMPORTS:\n", "        if re.search(rf\"\\b(?:import|from)\\s+{re.escape(mod)}\\b\", code):\n", "            return False\n", "    for b in BLOCKED_BUILTINS:\n", "        if re.search(rf\"\\b{re.escape(b)}\\b\", code): return False\n", "    try:\n", "        tree = ast.parse(code)\n", "    except SyntaxError:\n", "        return False\n", "    # AST pass: catches imports/calls regardless of formatting (e.g. `import json, os`).\n", "    for node in ast.walk(tree):\n", "        if isinstance(node, ast.Import):\n", "            for alias in node.names:\n", "                if alias.name.split('.')[0] in BLOCKED_IMPORTS: return False\n", "        elif isinstance(node, ast.ImportFrom):\n", "            if node.module and node.module.split('.')[0] in BLOCKED_IMPORTS: return False\n", "        elif isinstance(node, ast.Call):\n", "            if isinstance(node.func, ast.Name) and node.func.id in BLOCKED_BUILTINS: return False\n", "    return True\n", "\n", "class ExecutionResult:\n", "    def __init__(self, passed: bool, output: str, syntax_error: bool = False, timeout_error: bool = False):\n", "        self.passed = passed\n", "        self.output = output\n", "        self.syntax_error = syntax_error\n", "        self.timeout_error = timeout_error\n", "\n", "def execute_code(code: str, tests: str, timeout: int = 5) -> ExecutionResult:\n", "    full_code = f\"{code}\\n\\n{tests}\"\n", "    if not is_safe(full_code):\n", "        try:\n", "            ast.parse(full_code)\n", "            return ExecutionResult(passed=False, output=\"Unsafe import detected.\", syntax_error=False)\n", "        except SyntaxError as e:\n", "            return ExecutionResult(passed=False, output=f\"SyntaxError: {e}\", syntax_error=True)\n", "\n", "    with tempfile.TemporaryDirectory() as temp_dir:\n", "        temp_file = os.path.join(temp_dir, \"exec_script.py\")\n", "        with open(temp_file, \"w\") as f: f.write(full_code)\n", "        try:\n", "            result = subprocess.run([sys.executable, temp_file], capture_output=True, 
text=True, timeout=timeout)\n", " if result.returncode == 0:\n", " return ExecutionResult(passed=True, output=result.stdout)\n", " else:\n", " syntax_error = \"SyntaxError\" in result.stderr\n", " return ExecutionResult(passed=False, output=result.stderr, syntax_error=syntax_error)\n", " except subprocess.TimeoutExpired:\n", " return ExecutionResult(passed=False, output=\"Execution timed out.\", timeout_error=True)\n", " except Exception as e:\n", " return ExecutionResult(passed=False, output=str(e))" ] }, { "cell_type": "code", "execution_count": null, "id": "0f7263c1", "metadata": {}, "outputs": [], "source": [ "# 4. Plausibility Score\n", "from thefuzz import fuzz\n", "\n", "def compute_ast_distance(original_code: str, mutated_code: str) -> float:\n", " try:\n", " orig_ast = ast.dump(ast.parse(original_code))\n", " mut_ast = ast.dump(ast.parse(mutated_code))\n", " except SyntaxError:\n", " return 0.0\n", " ratio = fuzz.ratio(orig_ast, mut_ast)\n", " if 85 <= ratio: return 1.0 \n", " elif 50 <= ratio < 85: return max(0.1, (ratio - 50) / 35.0)\n", " else: return 0.0" ] }, { "cell_type": "code", "execution_count": null, "id": "10a37cca", "metadata": {}, "outputs": [], "source": [ "# 5. 
Prompts and Summarization (Dual Role Sampler)\n", "import re\n", "\n", "PROPOSER_PROMPT = \"\"\"You are the Proposer in a debugging self-play game.\n", "Given a clean Python function, inject a realistic logical bug into it.\n", "Rules:\n", "- Make exactly one small logical change.\n", "- Keep the code valid Python.\n", "- Keep the same function signature.\n", "- Preserve the overall structure and formatting as much as possible.\n", "- Prefer one of these mutation families: off_by_one, wrong_operator, wrong_builtin,\n", " condition_negation, loop_boundary_shift, or slice_boundary_corruption.\n", "- Aim for an edge-case behavior change, not a cosmetic refactor.\n", "- Avoid helper extraction, renaming-only edits, comment-only changes, or multi-line rewrites.\n", "- Return only the full modified Python code inside triple backticks.\n", "{focus_instruction}\n", "\n", "Clean function:\n", "```python\n", "{code}\n", "```\n", "\"\"\"\n", "\n", "SOLVER_PROMPT_CONCISE = \"\"\"You are the Solver in a debugging self-play game.\n", "Fix the bug with the smallest correct local change and return only the full fixed Python code inside triple backticks.\n", "\n", "Buggy function:\n", "```python\n", "{code}\n", "```\n", "\n", "Failure summary:\n", "{execution_result}\n", "\"\"\"\n", "\n", "TRACEBACK_HINTS = (\"Traceback\", \"AssertionError\", \"SyntaxError\", \"TypeError\", \"NameError\", \"ValueError\", \"IndexError\", \"KeyError\", \"ZeroDivisionError\", \"RuntimeError\", \"Timeout\",)\n", "\n", "def _truncate_text(text: str, max_chars: int) -> str:\n", " cleaned = re.sub(r\"[ \\t]+\", \" \", text.strip())\n", " if len(cleaned) <= max_chars: return cleaned\n", " return cleaned[: max(0, max_chars - 3)].rstrip() + \"...\"\n", "\n", "def summarize_failure_output(execution_result: str, *, max_lines: int = 3, max_chars: int = 220) -> str:\n", " text = execution_result.strip()\n", " if not text: return \"No failure output provided.\"\n", " if text in {\"Unsafe import detected.\", 
\"Execution timed out.\"} or text.startswith(\"SyntaxError:\"):\n", " return _truncate_text(text, max_chars)\n", " lines = [line.strip() for line in text.splitlines() if line.strip()]\n", " if not lines: return \"No failure output provided.\"\n", " traceback_positions = [idx for idx, line in enumerate(lines) if \"Traceback\" in line]\n", " if traceback_positions:\n", " tail = lines[traceback_positions[-1] :]\n", " if len(tail) > max_lines: lines = [tail[0], *tail[-(max_lines - 1) :]]\n", " else: lines = tail\n", " else:\n", " interesting_lines = [line for line in lines if any(hint in line for hint in TRACEBACK_HINTS)]\n", " if interesting_lines: lines = interesting_lines[-max_lines:]\n", " else: lines = lines[-max_lines:]\n", " return _truncate_text(\"\\n\".join(lines), max_chars)\n", "\n", "def sample_proposer_prompt(code: str, bug_focus: str | None = None) -> str:\n", " focus_instruction = \"\"\n", " if bug_focus:\n", " focus_instruction = f\"- Focus specifically on the `{bug_focus}` mutation family.\\n- Keep the edit local so the bug can be repaired with a small fix.\"\n", " return PROPOSER_PROMPT.format(code=code, focus_instruction=focus_instruction)\n", "\n", "def sample_solver_prompt(code: str, execution_result: str = \"\", *, mode: str = \"concise\") -> str:\n", " failure_output = summarize_failure_output(execution_result)\n", " return SOLVER_PROMPT_CONCISE.format(code=code, execution_result=failure_output)" ] }, { "cell_type": "code", "execution_count": null, "id": "db97c7b9", "metadata": {}, "outputs": [], "source": [ "# 6. 
Bug Operations & Injector\n", "import random\n", "import copy\n", "\n", "BUILTIN_PAIRS = {\"min\": \"max\", \"max\": \"min\", \"any\": \"all\", \"all\": \"any\", \"sum\": \"len\", \"len\": \"sum\"}\n", "\n", "def is_safe_injection(code: str) -> bool:\n", " for blocked in BLOCKED_IMPORTS:\n", " if f\"import {blocked}\" in code or f\"from {blocked}\" in code:\n", " return False\n", " return True\n", "\n", "class BugInjectorVisitor(ast.NodeTransformer):\n", " def __init__(self, target_operator: str):\n", " super().__init__()\n", " self.target_operator = target_operator\n", " self.mutated = False\n", "\n", " def visit_Constant(self, node):\n", " self.generic_visit(node)\n", " if self.mutated: return node\n", " if self.target_operator == \"off_by_one\" and isinstance(node.value, int) and not isinstance(node.value, bool):\n", " node.value += random.choice([-1, 1])\n", " self.mutated = True\n", " return node\n", "\n", " def visit_Compare(self, node):\n", " self.generic_visit(node)\n", " if self.mutated: return node\n", " if self.target_operator == \"wrong_operator\":\n", " if isinstance(node.ops[0], ast.Lt):\n", " node.ops[0] = ast.GtE()\n", " self.mutated = True\n", " elif isinstance(node.ops[0], ast.LtE):\n", " node.ops[0] = ast.Gt()\n", " self.mutated = True\n", " elif isinstance(node.ops[0], ast.Gt):\n", " node.ops[0] = ast.LtE()\n", " self.mutated = True\n", " elif isinstance(node.ops[0], ast.GtE):\n", " node.ops[0] = ast.Lt()\n", " self.mutated = True\n", " elif isinstance(node.ops[0], ast.Eq):\n", " node.ops[0] = ast.NotEq()\n", " self.mutated = True\n", " elif isinstance(node.ops[0], ast.NotEq):\n", " node.ops[0] = ast.Eq()\n", " self.mutated = True\n", " return node\n", "\n", " def visit_BinOp(self, node):\n", " self.generic_visit(node)\n", " if self.mutated: return node\n", " if self.target_operator == \"wrong_operator\":\n", " if isinstance(node.op, ast.Add):\n", " node.op = ast.Sub()\n", " self.mutated = True\n", " elif isinstance(node.op, ast.Sub):\n", " 
node.op = ast.Add()\n", " self.mutated = True\n", " elif isinstance(node.op, ast.Mult):\n", " node.op = ast.FloorDiv()\n", " self.mutated = True\n", " elif isinstance(node.op, ast.Div):\n", " node.op = ast.Mult()\n", " self.mutated = True\n", " return node\n", " \n", " def visit_Call(self, node):\n", " self.generic_visit(node)\n", " if self.mutated: return node\n", " if isinstance(node.func, ast.Name):\n", " if self.target_operator == \"wrong_builtin\" and node.func.id in BUILTIN_PAIRS:\n", " node.func.id = BUILTIN_PAIRS[node.func.id]\n", " self.mutated = True\n", " elif self.target_operator == \"loop_boundary_shift\" and node.func.id == \"range\":\n", " if len(node.args) == 1:\n", " node.args[0] = ast.BinOp(left=node.args[0], op=ast.Add(), right=ast.Constant(value=1))\n", " self.mutated = True\n", " elif len(node.args) == 2:\n", " node.args[0] = ast.BinOp(left=node.args[0], op=ast.Sub(), right=ast.Constant(value=1))\n", " self.mutated = True\n", " return node\n", "\n", " def visit_If(self, node):\n", " self.generic_visit(node)\n", " if self.mutated: return node\n", " if self.target_operator == \"condition_negation\":\n", " node.test = ast.UnaryOp(op=ast.Not(), operand=node.test)\n", " self.mutated = True\n", " if self.target_operator == \"missing_base_case\":\n", " for idx, child in enumerate(node.body):\n", " if isinstance(child, ast.Return):\n", " node.body[idx] = ast.Pass()\n", " self.mutated = True\n", " break\n", " return node\n", " \n", " def visit_Slice(self, node):\n", " self.generic_visit(node)\n", " if self.mutated: return node\n", " if self.target_operator == \"slice_boundary_corruption\":\n", " if node.lower is not None:\n", " node.lower = ast.BinOp(left=node.lower, op=ast.Add(), right=ast.Constant(value=1))\n", " self.mutated = True\n", " elif node.upper is not None:\n", " node.upper = ast.BinOp(left=node.upper, op=ast.Sub(), right=ast.Constant(value=1))\n", " self.mutated = True\n", " return node\n", " \n", " def visit_Assign(self, node):\n", " 
self.generic_visit(node)\n", " if self.mutated: return node\n", " if self.target_operator == \"variable_swap\" and getattr(node, \"targets\", None):\n", " if isinstance(node.targets[0], ast.Tuple) and len(node.targets[0].elts) >= 2:\n", " node.targets[0].elts[0], node.targets[0].elts[1] = node.targets[0].elts[1], node.targets[0].elts[0]\n", " self.mutated = True\n", " return node\n", "\n", "def inject_bug(original_code: str, proposed_operator: str) -> tuple[str, bool]:\n", " try: tree = ast.parse(original_code)\n", " except SyntaxError: return original_code, False\n", " injector = BugInjectorVisitor(proposed_operator)\n", " mutated_tree = injector.visit(copy.deepcopy(tree))\n", " ast.fix_missing_locations(mutated_tree)\n", " mutated_code = ast.unparse(mutated_tree)\n", " if mutated_code.strip() == original_code.strip(): return original_code, False\n", " if not is_safe_injection(mutated_code): return original_code, False\n", " try: ast.parse(mutated_code)\n", " except SyntaxError: return original_code, False\n", " return mutated_code, True\n", "\n", "def infer_bug_operator(original_code: str, candidate_code: str) -> str | None:\n", " try:\n", " original_tree = ast.parse(original_code)\n", " candidate_tree = ast.parse(candidate_code)\n", " except SyntaxError:\n", " return None\n", " if ast.dump(original_tree) == ast.dump(candidate_tree):\n", " return None\n", " return \"unknown\" # simplified for inline layout logic" ] }, { "cell_type": "code", "execution_count": null, "id": "73ba2c61", "metadata": {}, "outputs": [], "source": [ "# 7. 
Environment & Bug Bank\n", "V1_BUG_OPERATORS = (\"wrong_operator\", \"wrong_builtin\", \"condition_negation\", \"off_by_one\", \"loop_boundary_shift\", \"slice_boundary_corruption\",)\n", "MAX_VERIFIED_BUGS_PER_SEED = 4\n", "HOLDOUT_BUGS_PER_SEED = 1\n", "MAX_MUTATION_ATTEMPTS = 4\n", "BUG_OPERATOR_PRIORITY = {\"loop_boundary_shift\": 6, \"slice_boundary_corruption\": 5, \"condition_negation\": 4, \"wrong_operator\": 3, \"off_by_one\": 2, \"wrong_builtin\": 1}\n", "\n", "@dataclass(frozen=True)\n", "class BugSample:\n", " seed_id: str\n", " original_code: str\n", " buggy_code: str\n", " bug_operator: str\n", " execution_result: str\n", "\n", "@dataclass(frozen=True)\n", "class BugBank:\n", " train_samples: tuple[BugSample, ...]\n", " eval_samples: tuple[BugSample, ...]\n", "\n", "def validate_seed(seed: SeedSpec) -> None:\n", " result = execute_code(seed.original_code, seed.test)\n", " if result.syntax_error or not result.passed:\n", " raise ValueError(f\"Seed {seed.seed_id} does not pass.\")\n", "\n", "def _count_nonempty_lines(text: str) -> int:\n", " return sum(1 for line in text.splitlines() if line.strip())\n", "\n", "def _bug_difficulty_score(seed: SeedSpec, sample: BugSample) -> float:\n", " operator_score = BUG_OPERATOR_PRIORITY.get(sample.bug_operator, 0)\n", " ast_similarity = compute_ast_distance(seed.original_code, sample.buggy_code)\n", " execution_lines = _count_nonempty_lines(sample.execution_result)\n", " return float(operator_score) + ast_similarity + min(execution_lines / 4.0, 1.0)\n", "\n", "def _collect_verified_bugs(seed: SeedSpec) -> list[BugSample]:\n", " verified_samples: list[BugSample] = []\n", " seen_codes: set[str] = set()\n", " for bug_operator in V1_BUG_OPERATORS:\n", " for attempt in range(MAX_MUTATION_ATTEMPTS):\n", " random.seed(f\"{seed.seed_id}:{bug_operator}:{attempt}\")\n", " buggy_code, changed = inject_bug(seed.original_code, bug_operator)\n", " if not changed or buggy_code in seen_codes: continue\n", " result = 
execute_code(buggy_code, seed.test)\n", " if result.syntax_error or result.passed: continue\n", " seen_codes.add(buggy_code)\n", " verified_samples.append(BugSample(seed.seed_id, seed.original_code, buggy_code, bug_operator, result.output[:500] if result.output else \"\"))\n", " return verified_samples\n", "\n", "def build_bug_bank() -> BugBank:\n", " train_samples, eval_samples = [], []\n", " for seed in SEED_BANK:\n", " validate_seed(seed)\n", " verified_samples = sorted(_collect_verified_bugs(seed), key=lambda sample: _bug_difficulty_score(seed, sample), reverse=True)\n", " if len(verified_samples) <= HOLDOUT_BUGS_PER_SEED: raise ValueError(f\"Seed {seed.seed_id} under-produced.\")\n", " eval_samples.extend(verified_samples[:HOLDOUT_BUGS_PER_SEED])\n", " train_samples.extend(verified_samples[HOLDOUT_BUGS_PER_SEED : HOLDOUT_BUGS_PER_SEED + MAX_VERIFIED_BUGS_PER_SEED])\n", " return BugBank(tuple(train_samples), tuple(eval_samples))" ] }, { "cell_type": "code", "execution_count": null, "id": "2736e2d9", "metadata": {}, "outputs": [], "source": [ "# 8. 
Training Rewards\n", "import statistics\n", "from collections import deque\n", "\n", "solve_rate_history: dict[str, deque[float]] = {}\n", "def reset_reward_history() -> None: solve_rate_history.clear()\n", "def get_solve_rate(seed_id: str) -> float: return statistics.mean(solve_rate_history[seed_id]) if solve_rate_history.get(seed_id) else 0.5\n", "def record_solve_result(seed_id: str, solved: bool) -> None:\n", "    if seed_id not in solve_rate_history: solve_rate_history[seed_id] = deque(maxlen=20)\n", "    solve_rate_history[seed_id].append(1.0 if solved else 0.0)\n", "\n", "def is_effectively_unchanged(original_code: str, candidate_code: str) -> bool:\n", "    try: return ast.dump(ast.parse(original_code)) == ast.dump(ast.parse(candidate_code))\n", "    except SyntaxError: return original_code.strip() == candidate_code.strip()\n", "\n", "def compute_proposer_reward(meta: dict) -> float:\n", "    \"\"\"Score one proposer attempt from its execution metadata.\n", "\n", "    -0.5 if `syntax_error` or `unsafe_code`; 0.0 if `unchanged_code`;\n", "    -0.1 if `changed_but_passing`; 0.0 if `tests_passed` (missing key counts as passed);\n", "    otherwise 1.0 + `plausibility_score` + a learnability bonus.\n", "    \"\"\"\n", "    if meta.get(\"syntax_error\", False) or meta.get(\"unsafe_code\", False): return -0.5\n", "    if meta.get(\"unchanged_code\", False): return 0.0\n", "    # Must be checked before the generic `tests_passed` branch, otherwise this penalty is unreachable.\n", "    if meta.get(\"changed_but_passing\", False): return -0.1\n", "    if meta.get(\"tests_passed\", True): return 0.0\n", "    plausibility_bonus = meta.get(\"plausibility_score\", 0.0)\n", "    # Bonus when the solver's recent solve rate on this seed (last 20 results, 0.5 if none) is in [0.2, 0.8].\n", "    learnability_bonus = 1.0 if 0.2 <= get_solve_rate(meta[\"seed_id\"]) <= 0.8 else 0.0\n", "    return 1.0 + plausibility_bonus + learnability_bonus\n", "\n", "def compute_solver_reward(meta: dict) -> float:\n", "    \"\"\"Score one solver attempt: -0.5 for a syntax error or unsafe code, 1.0 if tests pass, else 0.0.\n", "\n", "    A missing `syntax_error` key is treated as a syntax error. Side effect: the outcome is\n", "    appended to `solve_rate_history`, which drives the proposer's learnability bonus.\n", "    \"\"\"\n", "    solved = meta.get(\"tests_passed\", False)\n", "    syntax_error = meta.get(\"syntax_error\", True)\n", "    unsafe_code = meta.get(\"unsafe_code\", False)\n", "    record_solve_result(meta[\"seed_id\"], solved and not syntax_error and not unsafe_code)\n", "    if syntax_error or unsafe_code: return -0.5\n", "    if solved: return 1.0\n", "    return 0.0" ] }, { "cell_type": "code", "execution_count": null, "id": "892a82b2", "metadata": {}, "outputs": [], "source": [ "# 9. 
Build the Dataset\n", "import math\n", "from datasets import Dataset\n", "from collections import Counter, defaultdict\n", "\n", "DEFAULT_SOLVER_WEIGHT = 2\n", "TARGETED_PROMPT_RATIO = 0.75\n", "\n", "def choose_proposer_bug_focus(seed_id: str, operators: list, operator_counts: Counter, focus_counters: Counter, row_index: int, total_rows: int) -> str | None:\n", " unique_operators = sorted(set(operators), key=lambda op: (operator_counts[op], op))\n", " if not unique_operators: return None\n", " if row_index >= math.ceil(total_rows * TARGETED_PROMPT_RATIO): return None\n", " chosen = min(unique_operators, key=lambda op: (focus_counters[op], operator_counts[op], op))\n", " focus_counters[chosen] += 1\n", " return chosen\n", "\n", "def build_weighted_proposer_rows(bug_bank, target_proposer_rows: int) -> list:\n", " if target_proposer_rows <= 0: return []\n", " operator_counts = Counter(sample.bug_operator for sample in bug_bank.train_samples)\n", " seed_to_operators = defaultdict(list)\n", " for sample in bug_bank.train_samples:\n", " seed_to_operators[sample.seed_id].append(sample.bug_operator)\n", " \n", " seed_weights = {seed.seed_id: 2 for seed in SEED_BANK} # Default weight for inline\n", " rows = []\n", " focus_counters = Counter()\n", " ordered_seeds = sorted(SEED_BANK, key=lambda seed: (-seed_weights[seed.seed_id], seed.seed_id))\n", "\n", " for seed in SEED_BANK[:target_proposer_rows]:\n", " bug_focus = choose_proposer_bug_focus(seed.seed_id, seed_to_operators[seed.seed_id], operator_counts, focus_counters, len(rows), target_proposer_rows)\n", " prompt_text = sample_proposer_prompt(seed.original_code, bug_focus=bug_focus)\n", " rows.append({\"prompt\": [{\"role\": \"user\", \"content\": prompt_text}], \"role\": \"proposer\", \"seed_id\": seed.seed_id, \"original_code\": seed.original_code, \"bug_focus\": bug_focus if bug_focus else \"\"})\n", "\n", " while len(rows) < target_proposer_rows:\n", " for seed in ordered_seeds:\n", " extra_weight = max(0, 
seed_weights[seed.seed_id] - 1)\n", "            for _ in range(extra_weight):\n", "                if len(rows) >= target_proposer_rows: break\n", "                bug_focus = choose_proposer_bug_focus(seed.seed_id, seed_to_operators[seed.seed_id], operator_counts, focus_counters, len(rows), target_proposer_rows)\n", "                prompt_text = sample_proposer_prompt(seed.original_code, bug_focus=bug_focus)\n", "                rows.append({\"prompt\": [{\"role\": \"user\", \"content\": prompt_text}], \"role\": \"proposer\", \"seed_id\": seed.seed_id, \"original_code\": seed.original_code, \"bug_focus\": bug_focus if bug_focus else \"\"})\n", "            if len(rows) >= target_proposer_rows: break\n", "    return rows\n", "\n", "def build_mixed_role_dataset(bug_bank) -> Dataset:\n", "    \"\"\"Build the GRPO dataset: one solver row per training bug, followed by proposer rows.\n", "\n", "    The proposer row count is max(1, ceil(solver_rows / DEFAULT_SOLVER_WEIGHT)),\n", "    or len(SEED_BANK) when the bank has no training bugs.\n", "    \"\"\"\n", "    rows = []\n", "    for bug_sample in bug_bank.train_samples:\n", "        prompt_text = sample_solver_prompt(bug_sample.buggy_code, bug_sample.execution_result)\n", "        rows.append({\n", "            \"prompt\": [{\"role\": \"user\", \"content\": prompt_text}],\n", "            \"role\": \"solver\", \"seed_id\": bug_sample.seed_id, \"original_code\": bug_sample.original_code, \"buggy_code\": bug_sample.buggy_code\n", "        })\n", "    target_proposer_rows = max(1, math.ceil(len(rows) / DEFAULT_SOLVER_WEIGHT)) if rows else len(SEED_BANK)\n", "    rows.extend(build_weighted_proposer_rows(bug_bank, target_proposer_rows))\n", "    return Dataset.from_list(rows)\n", "\n", "# Build the bug bank once and reuse it: the holdout evaluation reads `bug_bank.eval_samples`,\n", "# so it must come from the same bank whose train split produced `dataset` (execution\n", "# timeouts can make two independent builds differ, and building twice doubles the cost).\n", "bug_bank = build_bug_bank()\n", "dataset = build_mixed_role_dataset(bug_bank)\n", "print(\"Dataset size:\", len(dataset))" ] }, { "cell_type": "code", "execution_count": null, "id": "e85498cb", "metadata": {}, "outputs": [], "source": [ "# 10. 
TRL GRPO Training Setup\n", "import torch\n", "import importlib.util\n", "from transformers import AutoModelForCausalLM, AutoTokenizer\n", "from trl import GRPOConfig, GRPOTrainer\n", "import re\n", "\n", "DEFAULT_MODEL_ID = \"Qwen/Qwen2.5-Coder-3B-Instruct\" # Recommended default from DebugZero\n", "DEFAULT_MAX_PROMPT_LENGTH = 768\n", "DEFAULT_MAX_COMPLETION_LENGTH = 256\n", "\n", "def extract_python_code(text: str) -> str:\n", " match = re.search(r\"```(?:python)?\\s(.*?)```\", text, flags=re.DOTALL)\n", " if match: return match.group(1).strip()\n", " return text.strip()\n", "\n", "def completion_to_text(completion) -> str:\n", " if isinstance(completion, list) and completion:\n", " item = completion[0]\n", " return item.get(\"content\", \"\") if isinstance(item, dict) else str(item)\n", " return str(completion)\n", "\n", "def execute_candidate(seed: SeedSpec, candidate_code: str) -> dict[str, object]:\n", " result = execute_code(candidate_code, seed.test)\n", " execution_result = result.output[:500] if result.output else \"\"\n", " return {\n", " \"tests_passed\": result.passed, \"syntax_error\": result.syntax_error,\n", " \"unsafe_code\": execution_result.startswith(\"Unsafe import detected.\"),\n", " \"execution_result\": execution_result,\n", " }\n", "\n", "def prop_rew(prompts, completions, **kwargs):\n", " rewards = []\n", " roles = kwargs.get(\"role\", [])\n", " seed_ids = kwargs.get(\"seed_id\", [])\n", " original_codes = kwargs.get(\"original_code\", [])\n", " for i, completion in enumerate(completions):\n", " role = roles[i] if i < len(roles) else roles[0]\n", " if role != \"proposer\":\n", " rewards.append(0.0)\n", " continue\n", " \n", " seed_id = seed_ids[i] if i < len(seed_ids) else seed_ids[0]\n", " original_code = original_codes[i] if i < len(original_codes) else original_codes[0]\n", " seed = get_seed_by_id(seed_id)\n", " candidate_code = extract_python_code(completion_to_text(completion))\n", " exec_meta = execute_candidate(seed, 
candidate_code)\n", " \n", " unchanged = is_effectively_unchanged(original_code, candidate_code)\n", " proposer_meta = {\n", " \"seed_id\": seed.seed_id, \"tests_passed\": exec_meta[\"tests_passed\"], \"syntax_error\": exec_meta[\"syntax_error\"],\n", " \"unsafe_code\": exec_meta[\"unsafe_code\"], \"unchanged_code\": unchanged,\n", " \"changed_but_passing\": (not unchanged) and exec_meta[\"tests_passed\"] and (not exec_meta[\"syntax_error\"]),\n", " \"plausibility_score\": 0.0 if exec_meta[\"syntax_error\"] else compute_ast_distance(original_code, candidate_code)\n", " }\n", " rewards.append(compute_proposer_reward(proposer_meta))\n", " return rewards\n", "\n", "def solv_rew(prompts, completions, **kwargs):\n", " rewards = []\n", " roles = kwargs.get(\"role\", [])\n", " seed_ids = kwargs.get(\"seed_id\", [])\n", " for i, completion in enumerate(completions):\n", " role = roles[i] if i < len(roles) else roles[0]\n", " if role != \"solver\":\n", " rewards.append(0.0)\n", " continue\n", " \n", " seed_id = seed_ids[i] if i < len(seed_ids) else seed_ids[0]\n", " seed = get_seed_by_id(seed_id)\n", " candidate_code = extract_python_code(completion_to_text(completion))\n", " exec_meta = execute_candidate(seed, candidate_code)\n", " \n", " rewards.append(compute_solver_reward({\"seed_id\": seed.seed_id, \"tests_passed\": exec_meta[\"tests_passed\"], \"syntax_error\": exec_meta[\"syntax_error\"], \"unsafe_code\": exec_meta[\"unsafe_code\"]}))\n", " return rewards\n", "\n", "# Load Model\n", "model, tokenizer = None, None\n", "try:\n", " from unsloth import FastLanguageModel, PatchFastRL\n", " PatchFastRL(\"GRPO\", FastLanguageModel)\n", " model, tokenizer = FastLanguageModel.from_pretrained(model_name=\"unsloth/Qwen2.5-Coder-3B-Instruct\", max_seq_length=1024, load_in_4bit=True, fast_inference=True)\n", " model = FastLanguageModel.get_peft_model(model, r=16, target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], 
lora_alpha=16, bias=\"none\", use_gradient_checkpointing=\"unsloth\")\n",
"except (ImportError, NotImplementedError, RuntimeError, OSError) as unsloth_error:\n",
"    # Unsloth can fail to load for more reasons than a missing package (e.g. a Kaggle/Colab\n",
"    # CUDA mismatch or no visible GPU), and those do not necessarily surface as ImportError.\n",
"    # In any of these cases fall back to standard HuggingFace PEFT (LoRA), printing the reason.\n",
"    print(f\"Unsloth unavailable ({type(unsloth_error).__name__}: {unsloth_error}); falling back to HuggingFace PEFT.\")\n",
"    model, tokenizer = None, None  # release any partially-loaded Unsloth objects before reloading\n",
"    from peft import LoraConfig, get_peft_model\n",
"    tokenizer = AutoTokenizer.from_pretrained(DEFAULT_MODEL_ID)\n",
"    model = AutoModelForCausalLM.from_pretrained(DEFAULT_MODEL_ID, torch_dtype=torch.bfloat16, device_map=\"auto\")\n",
"    peft_config = LoraConfig(r=16, lora_alpha=16, target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], lora_dropout=0, bias=\"none\", task_type=\"CAUSAL_LM\")\n",
"    model = get_peft_model(model, peft_config)\n",
"\n",
"if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token"
] }, { "cell_type": "code", "execution_count": null, "id": "c9cac092", "metadata": {}, "outputs": [], "source": [
"# 11. Run GRPO Training & Plot Metrics\n",
"import importlib.util  # a bare `import importlib` does not guarantee the `util` submodule is loaded\n",
"\n",
"from transformers import TrainerCallback\n",
"\n",
"has_bitsandbytes = importlib.util.find_spec(\"bitsandbytes\") is not None\n",
"\n",
"training_args = GRPOConfig(\n",
"    output_dir=\"debugzero_model\",\n",
"    per_device_train_batch_size=1,\n",
"    gradient_accumulation_steps=4,\n",
"    learning_rate=2e-5,\n",
"    max_steps=50,\n",
"    num_generations=2,\n",
"    max_prompt_length=DEFAULT_MAX_PROMPT_LENGTH,\n",
"    max_completion_length=DEFAULT_MAX_COMPLETION_LENGTH,\n",
"    bf16=False, fp16=True,\n",
"    logging_steps=5,\n",
"    optim=\"adamw_8bit\" if has_bitsandbytes else \"adamw_torch\",\n",
"    report_to=\"none\",\n",
"    disable_tqdm=True,\n",
")\n",
"\n",
"\n",
"def reward_metric(entry, func_name, default=0.0):\n",
"    \"\"\"Return the logged mean reward of reward function `func_name` from one log entry.\n",
"\n",
"    Depending on the TRL version, the per-function mean is logged under\n",
"    `rewards/<func_name>/mean` (newer) or `rewards/<func_name>` (older); both are tried\n",
"    before falling back to `default`.\n",
"    \"\"\"\n",
"    return entry.get(f\"rewards/{func_name}/mean\", entry.get(f\"rewards/{func_name}\", default))\n",
"\n",
"\n",
"class TableMetricsCallback(TrainerCallback):\n",
"    \"\"\"Print a fixed-width table row for every log event that carries a loss.\"\"\"\n",
"\n",
"    def on_train_begin(self, args, state, control, **kwargs):\n",
"        print(f\"{'Step':<8} | {'Loss':<10} | {'Prop Rew':<10} | {'Solv Rew':<10} | {'Tot Rew':<10} | {'Entropy':<10}\")\n",
"        print(\"-\" * 70)\n",
"\n",
"    def on_log(self, args, state, control, logs=None, **kwargs):\n",
"        if logs and \"loss\" in logs:\n",
"            loss = logs.get(\"loss\", 0.0)\n",
"            p_reward = reward_metric(logs, \"prop_rew\")\n",
"            s_reward = reward_metric(logs, \"solv_rew\")\n",
"            # Use the aggregate reward when logged; otherwise approximate it by the sum of both means.\n",
"            total_reward = logs.get(\"reward\", p_reward + s_reward)\n",
"            entropy = logs.get(\"entropy\", 0.0)\n",
"            print(f\"{state.global_step:<8} | {loss:<10.4f} | {p_reward:<10.4f} | {s_reward:<10.4f} | {total_reward:<10.4f} | {entropy:<10.4f}\")\n",
"\n",
"\n",
"trainer = GRPOTrainer(\n",
"    model=model,\n",
"    reward_funcs=[prop_rew, solv_rew],\n",
"    args=training_args,\n",
"    train_dataset=dataset,\n",
"    processing_class=tokenizer,\n",
"    callbacks=[TableMetricsCallback()],\n",
")\n",
"\n",
"print(f\"Starting GRPO training for {training_args.max_steps} episodes (steps)...\")\n",
"print(\"To change the number of episodes, modify 'max_steps' in GRPOConfig above.\")\n",
"train_result = trainer.train()\n",
"print(\"Training Complete! View debugzero_model for artifacts.\")\n",
"\n",
"# 12. Plot Metrics natively in Colab\n",
"import matplotlib.pyplot as plt\n",
"\n",
"log_history = trainer.state.log_history\n",
"# Only training-step log entries carry a loss; eval/summary entries are skipped.\n",
"loss_entries = [entry for entry in log_history if \"loss\" in entry]\n",
"steps = [entry[\"step\"] for entry in loss_entries]\n",
"losses = [entry[\"loss\"] for entry in loss_entries]\n",
"p_rewards = [reward_metric(entry, \"prop_rew\") for entry in loss_entries]\n",
"s_rewards = [reward_metric(entry, \"solv_rew\") for entry in loss_entries]\n",
"\n",
"fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
"\n",
"# Loss Plot\n",
"if steps and losses:\n",
"    axes[0].plot(steps[:len(losses)], losses, marker='o', color='purple', label=\"Total Loss\")\n",
"    axes[0].set_title(\"GRPO Training Loss\")\n",
"    axes[0].set_xlabel(\"Steps (Episodes)\")\n",
"    axes[0].set_ylabel(\"Loss\")\n",
"    axes[0].grid(True, linestyle=\"--\", alpha=0.5)\n",
"    axes[0].legend()\n",
"\n",
"# Rewards Plot\n",
"if steps and (p_rewards or s_rewards):\n",
"    if p_rewards:\n",
"        axes[1].plot(steps[:len(p_rewards)], p_rewards, marker='s', color='orange', label=\"Proposer Reward\")\n",
"    if s_rewards:\n",
"        axes[1].plot(steps[:len(s_rewards)], s_rewards, marker='^', color='green', label=\"Solver Reward\")\n",
"\n",
"    axes[1].set_title(\"GRPO Environment Rewards Evolution\")\n",
"    axes[1].set_xlabel(\"Steps (Episodes)\")\n",
"    axes[1].set_ylabel(\"Reward\")\n",
"    axes[1].grid(True, linestyle=\"--\", alpha=0.5)\n",
"    axes[1].legend()\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
] }, { "cell_type": "code", "execution_count": null, "id": "0048cecb", "metadata": {}, "outputs": [], "source": [
"# 13. Interactive Verification\n",
"# We wrap tqdm around some final manual checks to give a visual indicator for eval.\n",
"from tqdm.auto import tqdm\n",
"\n",
"print(\"Running final evaluations across the holdout set:\")\n",
"model.eval()\n",
"\n",
"# Testing Solver\n",
"correct = 0\n",
"total_evals = len(bug_bank.eval_samples)\n",
"\n",
"print(f\"Validating {total_evals} Holdout bugs...\")\n",
"for sample in tqdm(bug_bank.eval_samples, desc=\"Solver Eval\"):\n",
"    prompt = sample_solver_prompt(sample.buggy_code, sample.execution_result)\n",
"\n",
"    prompt_text = tokenizer.apply_chat_template([{\"role\": \"user\", \"content\": prompt}], tokenize=False, add_generation_prompt=True)\n",
"    encoded = tokenizer(prompt_text, return_tensors=\"pt\").to(model.device)\n",
"    # Give eval the same completion budget as training (max_completion_length) so longer fixes are not cut off.\n",
"    out = model.generate(**encoded, max_new_tokens=DEFAULT_MAX_COMPLETION_LENGTH, pad_token_id=tokenizer.pad_token_id, do_sample=False)\n",
"\n",
"    # Decode only the newly generated tokens, not the echoed prompt.\n",
"    generated_code = tokenizer.decode(out[0][encoded.input_ids.shape[-1]:], skip_special_tokens=True)\n",
"    clean_code = extract_python_code(generated_code)\n",
"\n",
"    # Check if the generated solution passes the test\n",
"    seed = get_seed_by_id(sample.seed_id)\n",
"    exec_meta = execute_candidate(seed, clean_code)\n",
"\n",
"    if exec_meta[\"tests_passed\"] and not exec_meta[\"syntax_error\"]:\n",
"        correct += 1\n",
"\n",
"# Guard against an empty holdout set, which would otherwise raise ZeroDivisionError.\n",
"pass_rate = correct / total_evals if total_evals else 0.0\n",
"print(f\"Holdout Set Solver Pass Rate: {correct}/{total_evals} ({pass_rate:.1%})\")"
] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.11" } }, "nbformat": 4, "nbformat_minor": 5 }