"""CodeSensei — CodeDebug OpenEnv Environment.

Implements the core Environment with reset(), step(), and state following
the OpenEnv 3-method pattern. Manages episodes, tracks attempts, computes
multi-signal rewards, and detects repeated fixes.
"""

from __future__ import annotations

import hashlib
import json
import os
import random
import uuid
from typing import Dict, List, Optional, Tuple

from env.models import CodeDebugAction, CodeDebugObservation, CodeDebugState, TestResult
from env.server.test_runner import run_tests
from env.server.sandbox import check_syntax

# Load bug dataset
_DATA_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "data")
_BUG_DATASET: List[dict] = []

# Episode / reward tuning knobs shared by reset() and step().
_MAX_ATTEMPTS = 6
_MIN_REWARD = 0.01  # 0.0 is forbidden by Phase 2, so this is the floor
_MAX_REWARD = 0.99
_TASK_PREFIX = "debug-"


def _load_dataset() -> None:
    """Populate the module-level ``_BUG_DATASET``.

    Reads ``bug_dataset.json`` from ``_DATA_DIR`` when it exists. If the file
    is missing, or its content is not a non-empty list, the built-in minimal
    dataset is used instead so that ``reset()`` always has a bug to sample.
    """
    global _BUG_DATASET
    dataset_path = os.path.join(_DATA_DIR, "bug_dataset.json")
    if os.path.exists(dataset_path):
        with open(dataset_path, "r", encoding="utf-8") as f:
            loaded = json.load(f)
        # An empty or non-list payload would make random.choice() fail later.
        if isinstance(loaded, list) and loaded:
            _BUG_DATASET = loaded
            return
    # Fallback: built-in minimal dataset for testing
    _BUG_DATASET = _get_builtin_dataset()


def _get_builtin_dataset() -> List[dict]:
    """Built-in minimal dataset used when bug_dataset.json is missing.

    Each entry holds the function name, a buggy and a corrected implementation
    (as source strings), a human-readable bug description, and a list of
    ``{"name", "code"}`` assertion snippets.
    """
    return [
        {
            "function_name": "add_numbers",
            "buggy_code": (
                "def add_numbers(a, b):\n"
                "    return a - b"
            ),
            "correct_code": (
                "def add_numbers(a, b):\n"
                "    return a + b"
            ),
            "bug_description": "Uses subtraction instead of addition",
            "tests": [
                {"name": "basic addition", "code": "assert add_numbers(2, 3) == 5"},
                {"name": "zero addition", "code": "assert add_numbers(0, 0) == 0"},
                {"name": "negative addition", "code": "assert add_numbers(-1, 1) == 0"},
            ],
        },
        {
            "function_name": "find_max",
            "buggy_code": (
                "def find_max(lst):\n"
                "    if not lst:\n"
                "        return None\n"
                "    result = lst[0]\n"
                "    for x in lst:\n"
                "        if x < result:\n"
                "            result = x\n"
                "    return result"
            ),
            "correct_code": (
                "def find_max(lst):\n"
                "    if not lst:\n"
                "        return None\n"
                "    result = lst[0]\n"
                "    for x in lst:\n"
                "        if x > result:\n"
                "            result = x\n"
                "    return result"
            ),
            "bug_description": "Uses < instead of > (finds minimum instead of maximum)",
            "tests": [
                {"name": "basic max", "code": "assert find_max([1, 3, 2]) == 3"},
                {"name": "single element", "code": "assert find_max([5]) == 5"},
                {"name": "negative numbers", "code": "assert find_max([-1, -5, -2]) == -1"},
                {"name": "empty list", "code": "assert find_max([]) is None"},
            ],
        },
        {
            "function_name": "reverse_string",
            "buggy_code": (
                "def reverse_string(s):\n"
                "    return s[1:]"
            ),
            "correct_code": (
                "def reverse_string(s):\n"
                "    return s[::-1]"
            ),
            "bug_description": "Slices from index 1 instead of reversing",
            "tests": [
                {"name": "basic reverse", "code": 'assert reverse_string("hello") == "olleh"'},
                {"name": "empty string", "code": 'assert reverse_string("") == ""'},
                {"name": "single char", "code": 'assert reverse_string("a") == "a"'},
                {"name": "palindrome", "code": 'assert reverse_string("racecar") == "racecar"'},
            ],
        },
        {
            "function_name": "fibonacci",
            "buggy_code": (
                "def fibonacci(n):\n"
                "    if n <= 0:\n"
                "        return 0\n"
                "    if n == 1:\n"
                "        return 1\n"
                "    return fibonacci(n - 1) + fibonacci(n - 3)"
            ),
            "correct_code": (
                "def fibonacci(n):\n"
                "    if n <= 0:\n"
                "        return 0\n"
                "    if n == 1:\n"
                "        return 1\n"
                "    return fibonacci(n - 1) + fibonacci(n - 2)"
            ),
            "bug_description": "Recursive call uses n-3 instead of n-2",
            "tests": [
                {"name": "fib(0)", "code": "assert fibonacci(0) == 0"},
                {"name": "fib(1)", "code": "assert fibonacci(1) == 1"},
                {"name": "fib(5)", "code": "assert fibonacci(5) == 5"},
                {"name": "fib(10)", "code": "assert fibonacci(10) == 55"},
            ],
        },
        {
            "function_name": "count_vowels",
            "buggy_code": (
                "def count_vowels(s):\n"
                "    count = 0\n"
                "    for c in s:\n"
                "        if c in 'aeiou':\n"
                "            count += 1\n"
                "    return count"
            ),
            "correct_code": (
                "def count_vowels(s):\n"
                "    count = 0\n"
                "    for c in s.lower():\n"
                "        if c in 'aeiou':\n"
                "            count += 1\n"
                "    return count"
            ),
            "bug_description": "Does not handle uppercase vowels",
            "tests": [
                {"name": "lowercase", "code": "assert count_vowels('hello') == 2"},
                {"name": "uppercase", "code": "assert count_vowels('HELLO') == 2"},
                {"name": "mixed case", "code": "assert count_vowels('HeLLo') == 2"},
                {"name": "no vowels", "code": "assert count_vowels('xyz') == 0"},
            ],
        },
        {
            "function_name": "is_palindrome",
            "buggy_code": (
                "def is_palindrome(s):\n"
                "    s = s.lower()\n"
                "    return s == s[::-1]"
            ),
            "correct_code": (
                "def is_palindrome(s):\n"
                "    s = ''.join(c for c in s.lower() if c.isalnum())\n"
                "    return s == s[::-1]"
            ),
            "bug_description": "Does not strip non-alphanumeric characters before checking",
            "tests": [
                {"name": "basic palindrome", "code": "assert is_palindrome('racecar') == True"},
                # The reference fix strips spaces, so 'race car' IS a palindrome;
                # expecting False here made the episode unsolvable with correct_code.
                {"name": "with spaces", "code": "assert is_palindrome('race car') == True"},
                {
                    "name": "with punctuation",
                    "code": "assert is_palindrome('A man, a plan, a canal: Panama') == True",
                },
                {"name": "not palindrome", "code": "assert is_palindrome('hello') == False"},
            ],
        },
        {
            "function_name": "flatten_list",
            "buggy_code": (
                "def flatten_list(lst):\n"
                "    result = []\n"
                "    for item in lst:\n"
                "        if isinstance(item, list):\n"
                "            result.append(item)\n"
                "        else:\n"
                "            result.append(item)\n"
                "    return result"
            ),
            "correct_code": (
                "def flatten_list(lst):\n"
                "    result = []\n"
                "    for item in lst:\n"
                "        if isinstance(item, list):\n"
                "            result.extend(flatten_list(item))\n"
                "        else:\n"
                "            result.append(item)\n"
                "    return result"
            ),
            "bug_description": "Appends nested lists instead of recursively flattening them",
            "tests": [
                {
                    "name": "nested",
                    "code": "assert flatten_list([1, [2, 3], [4, [5]]]) == [1, 2, 3, 4, 5]",
                },
                {"name": "already flat", "code": "assert flatten_list([1, 2, 3]) == [1, 2, 3]"},
                {"name": "empty", "code": "assert flatten_list([]) == []"},
                {"name": "deep nesting", "code": "assert flatten_list([[[[1]]]]) == [1]"},
            ],
        },
        {
            "function_name": "binary_search",
            "buggy_code": (
                "def binary_search(arr, target):\n"
                "    left, right = 0, len(arr) - 1\n"
                "    while left < right:\n"
                "        mid = (left + right) // 2\n"
                "        if arr[mid] == target:\n"
                "            return mid\n"
                "        elif arr[mid] < target:\n"
                "            left = mid\n"
                "        else:\n"
                "            right = mid - 1\n"
                "    return -1"
            ),
            "correct_code": (
                "def binary_search(arr, target):\n"
                "    left, right = 0, len(arr) - 1\n"
                "    while left <= right:\n"
                "        mid = (left + right) // 2\n"
                "        if arr[mid] == target:\n"
                "            return mid\n"
                "        elif arr[mid] < target:\n"
                "            left = mid + 1\n"
                "        else:\n"
                "            right = mid - 1\n"
                "    return -1"
            ),
            "bug_description": (
                "Uses < instead of <= in while condition, and left=mid instead of "
                "left=mid+1 (infinite loop / misses elements)"
            ),
            "tests": [
                {"name": "found middle", "code": "assert binary_search([1,2,3,4,5], 3) == 2"},
                {"name": "found first", "code": "assert binary_search([1,2,3,4,5], 1) == 0"},
                {"name": "found last", "code": "assert binary_search([1,2,3,4,5], 5) == 4"},
                {"name": "not found", "code": "assert binary_search([1,2,3,4,5], 6) == -1"},
            ],
        },
        {
            "function_name": "merge_sorted",
            "buggy_code": (
                "def merge_sorted(a, b):\n"
                "    result = []\n"
                "    i = j = 0\n"
                "    while i < len(a) and j < len(b):\n"
                "        if a[i] <= b[j]:\n"
                "            result.append(a[i])\n"
                "            i += 1\n"
                "        else:\n"
                "            result.append(b[j])\n"
                "            j += 1\n"
                "    return result"
            ),
            "correct_code": (
                "def merge_sorted(a, b):\n"
                "    result = []\n"
                "    i = j = 0\n"
                "    while i < len(a) and j < len(b):\n"
                "        if a[i] <= b[j]:\n"
                "            result.append(a[i])\n"
                "            i += 1\n"
                "        else:\n"
                "            result.append(b[j])\n"
                "            j += 1\n"
                "    result.extend(a[i:])\n"
                "    result.extend(b[j:])\n"
                "    return result"
            ),
            "bug_description": "Missing the remaining elements after the while loop ends",
            "tests": [
                {
                    "name": "basic merge",
                    "code": "assert merge_sorted([1,3,5], [2,4,6]) == [1,2,3,4,5,6]",
                },
                {"name": "one empty", "code": "assert merge_sorted([], [1,2,3]) == [1,2,3]"},
                {"name": "both empty", "code": "assert merge_sorted([], []) == []"},
                {
                    "name": "unequal length",
                    "code": "assert merge_sorted([1], [2,3,4]) == [1,2,3,4]",
                },
            ],
        },
        {
            "function_name": "remove_duplicates",
            "buggy_code": (
                "def remove_duplicates(lst):\n"
                "    seen = set()\n"
                "    result = []\n"
                "    for item in lst:\n"
                "        if item in seen:\n"
                "            result.append(item)\n"
                "        seen.add(item)\n"
                "    return result"
            ),
            "correct_code": (
                "def remove_duplicates(lst):\n"
                "    seen = set()\n"
                "    result = []\n"
                "    for item in lst:\n"
                "        if item not in seen:\n"
                "            result.append(item)\n"
                "        seen.add(item)\n"
                "    return result"
            ),
            "bug_description": "Condition is inverted: keeps duplicates and removes unique items",
            "tests": [
                {
                    "name": "basic dedup",
                    "code": "assert remove_duplicates([1,2,2,3,3,3]) == [1,2,3]",
                },
                {"name": "no duplicates", "code": "assert remove_duplicates([1,2,3]) == [1,2,3]"},
                {"name": "all same", "code": "assert remove_duplicates([5,5,5]) == [5]"},
                {"name": "empty", "code": "assert remove_duplicates([]) == []"},
            ],
        },
    ]


class CodeDebugEnvironment:
    """OpenEnv-compatible environment for code debugging RL.

    The agent receives a buggy Python function and must propose fixes.
    Each step runs the proposed fix against test cases and returns
    multi-signal reward feedback.
    """

    def __init__(self):
        # (Re)load the dataset so a freshly created environment sees the
        # current bug_dataset.json, or the built-in fallback.
        _load_dataset()
        # Active episodes keyed by session id; entries are removed when an
        # episode finishes (solved or out of attempts).
        self._sessions: Dict[str, CodeDebugState] = {}

    def reset(self, session_id: str = "", task: Optional[str] = None) -> CodeDebugObservation:
        """Start a new episode: sample a buggy function.

        Args:
            session_id: WebSocket session ID. Auto-generated if empty.
            task: Optional task name from openenv.yaml (e.g. "debug-add_numbers").
                If provided, selects the matching bug. Otherwise (or when no
                bug matches) picks randomly.

        Returns:
            Initial observation with the buggy code and test info.
        """
        if not session_id:
            session_id = str(uuid.uuid4())

        # Select bug by task name or randomly
        bug = None
        if task:
            # Strip only a *leading* "debug-" to get function_name
            # (e.g. "debug-add_numbers" -> "add_numbers"); a "debug-" that
            # appears later in the name must be left intact.
            if task.startswith(_TASK_PREFIX):
                fn_name = task[len(_TASK_PREFIX):]
            else:
                fn_name = task
            bug = next((b for b in _BUG_DATASET if b["function_name"] == fn_name), None)
        if bug is None:
            bug = random.choice(_BUG_DATASET)

        # Create state
        state = CodeDebugState(
            episode_id=str(uuid.uuid4()),
            session_id=session_id,
            attempt=0,
            max_attempts=_MAX_ATTEMPTS,
            original_bug=bug["buggy_code"],
            current_code=bug["buggy_code"],
            bug_description=bug["bug_description"],
            function_name=bug["function_name"],
            tests_passed_history=[],
            fix_hashes=[],
            solved=False,
        )
        # Store the bug data in state for test access
        state._bug_data = bug  # type: ignore[attr-defined]
        self._sessions[session_id] = state

        # Run tests on the buggy code to show initial failure
        test_results, passed, total, error_output = run_tests(
            bug["buggy_code"], bug["tests"]
        )
        feedback = self._build_feedback(bug, test_results, passed, total, is_initial=True)

        return CodeDebugObservation(
            buggy_code=bug["buggy_code"],
            current_code=bug["buggy_code"],
            error_output=error_output,
            test_results=test_results,
            tests_passed=passed,
            tests_total=total,
            reward=_MIN_REWARD,  # Non-zero initial reward (0.0 is forbidden by Phase 2)
            done=False,
            attempt=0,
            max_attempts=_MAX_ATTEMPTS,
            feedback=feedback,
        )

    def step(self, action: CodeDebugAction) -> CodeDebugObservation:
        """Apply the agent's proposed fix and evaluate.

        The reward is built from several signals (repetition penalty,
        correctness / progress bonus, stagnation and runtime-error penalties)
        and then clamped into [0.01, 0.99]. The session is discarded once the
        episode is done (solved or attempts exhausted).

        Args:
            action: CodeDebugAction with the proposed fix.

        Returns:
            Observation with test results, reward, and feedback.
        """
        session_id = action.session_id
        if session_id not in self._sessions:
            # NOTE(review): this -1.0 lies outside the (0, 1) reward range the
            # other paths enforce; left unchanged in case clients rely on it
            # to detect an invalid session — confirm with the validator rules.
            return CodeDebugObservation(
                buggy_code="",
                current_code="",
                error_output="Invalid session_id. Call reset() first.",
                reward=-1.0,
                done=True,
                feedback="Error: Invalid session. Please call reset() first.",
            )

        state = self._sessions[session_id]
        bug = state._bug_data  # type: ignore[attr-defined]

        # Increment attempt
        state.attempt += 1

        # --- Reward computation ---
        total_reward = 0.0
        proposed_fix = action.proposed_fix.strip()

        # 1. Syntax check
        is_valid_syntax, syntax_error = check_syntax(proposed_fix)
        if not is_valid_syntax:
            state.tests_passed_history.append(0)
            state.current_code = proposed_fix
            done = state.attempt >= state.max_attempts
            feedback = self._build_fix_feedback(
                bug, [], 0, len(bug["tests"]),
                syntax_error=syntax_error,
                attempt=state.attempt,
                max_attempts=state.max_attempts,
            )
            if done:
                self._sessions.pop(session_id, None)
            return CodeDebugObservation(
                buggy_code=bug["buggy_code"],
                current_code=proposed_fix,
                error_output=syntax_error,
                test_results=[],
                tests_passed=0,
                tests_total=len(bug["tests"]),
                reward=_MIN_REWARD,  # Clamped: syntax error gives minimum reward, not -1.0
                done=done,
                attempt=state.attempt,
                max_attempts=state.max_attempts,
                feedback=feedback,
            )

        # 2. Repetition check (only syntactically valid fixes are hashed)
        fix_hash = hashlib.sha256(proposed_fix.encode()).hexdigest()
        is_repeated = fix_hash in state.fix_hashes
        state.fix_hashes.append(fix_hash)
        if is_repeated:
            total_reward -= 0.5

        # 3. Run tests
        test_results, passed, total, error_output = run_tests(
            proposed_fix, bug["tests"]
        )

        # 4. Correctness reward
        if passed == total:
            total_reward += 2.0
            state.solved = True
        elif state.tests_passed_history:
            prev_best = max(state.tests_passed_history)
            if passed > prev_best:
                total_reward += 0.5  # Progress
            elif passed > 0:
                total_reward -= 0.3  # Stagnation (no better than the previous best)
        elif passed > 0:
            total_reward += 0.5  # First partial success

        # 5. Runtime error penalty (0 tests passed, no syntax error).
        # Syntax is known to be valid here: invalid fixes returned early above.
        if passed == 0 and error_output:
            total_reward -= 0.5

        state.tests_passed_history.append(passed)
        state.current_code = proposed_fix

        # Strictly bound the reward to (0, 1) as required by Phase 2 Deep Validation
        total_reward = min(max(total_reward, _MIN_REWARD), _MAX_REWARD)

        done = state.solved or state.attempt >= state.max_attempts
        feedback = self._build_fix_feedback(
            bug, test_results, passed, total,
            attempt=state.attempt,
            max_attempts=state.max_attempts,
            is_repeated=is_repeated,
        )

        if done:
            # Clean up session
            self._sessions.pop(session_id, None)

        return CodeDebugObservation(
            buggy_code=bug["buggy_code"],
            current_code=proposed_fix,
            error_output=error_output,
            test_results=test_results,
            tests_passed=passed,
            tests_total=total,
            reward=total_reward,
            done=done,
            attempt=state.attempt,
            max_attempts=state.max_attempts,
            feedback=feedback,
        )

    def get_state(self, session_id: str) -> Optional[CodeDebugState]:
        """Get current state for a session.

        Args:
            session_id: WebSocket session ID.

        Returns:
            Current state or None if session not found.
        """
        return self._sessions.get(session_id)

    # --- Feedback builders ---

    def _build_feedback(
        self,
        bug: dict,
        test_results: List[TestResult],
        passed: int,
        total: int,
        is_initial: bool = False,
    ) -> str:
        """Render the markdown bug report shown at the start of an episode.

        Args:
            bug: Dataset entry; ``function_name`` and ``buggy_code`` are read.
            test_results: Results of running the tests on the buggy code.
            passed: Number of passing tests.
            total: Total number of tests.
            is_initial: When False nothing is rendered and "" is returned.

        Returns:
            The report as a single newline-joined string.
        """
        lines = []
        if is_initial:
            lines.append(f"## Bug Report: `{bug['function_name']}`")
            lines.append(f"The function `{bug['function_name']}` has a bug.")
            lines.append(f"**Current test results:** {passed}/{total} tests passing")
            lines.append("")
            lines.append("### Buggy Code:")
            lines.append(f"```python\n{bug['buggy_code']}\n```")
            lines.append("")
            lines.append("### Failing Tests:")
            for tr in test_results:
                if not tr.passed:
                    lines.append(f"- ❌ {tr.test_name}: {tr.error_message}")
            lines.append("")
            lines.append("Please provide the corrected function.")
        return "\n".join(lines)

    def _build_fix_feedback(
        self,
        bug: dict,
        test_results: List[TestResult],
        passed: int,
        total: int,
        syntax_error: str = "",
        attempt: int = 0,
        max_attempts: int = _MAX_ATTEMPTS,
        is_repeated: bool = False,
    ) -> str:
        """Render the markdown feedback for one fix attempt.

        Args:
            bug: Dataset entry for the current episode (not read here).
            test_results: Per-test results for the proposed fix.
            passed: Number of passing tests.
            total: Total number of tests.
            syntax_error: Non-empty when the fix failed the syntax check; the
                message then contains only the syntax error.
            attempt: 1-based number of the attempt just made.
            max_attempts: Attempt budget for the episode.
            is_repeated: True when the same fix was already submitted.

        Returns:
            The feedback as a single newline-joined string.
        """
        lines = [f"## Attempt {attempt}/{max_attempts}"]

        if syntax_error:
            lines.append(f"❌ **Syntax Error:** {syntax_error}")
            lines.append("Please fix the syntax and try again.")
            return "\n".join(lines)

        if is_repeated:
            lines.append(
                "⚠️ **Warning:** You submitted the same fix as before. Try a different approach."
            )

        lines.append(f"**Tests:** {passed}/{total} passing")

        if passed == total:
            lines.append("✅ **All tests pass! Bug fixed successfully!**")
        else:
            lines.append("")
            lines.append("### Test Results:")
            for tr in test_results:
                status = "✅" if tr.passed else "❌"
                line = f"- {status} {tr.test_name}"
                if not tr.passed and tr.error_message:
                    line += f": {tr.error_message}"
                lines.append(line)

            remaining = max_attempts - attempt
            if remaining > 0:
                lines.append(f"\n{remaining} attempts remaining.")

        return "\n".join(lines)