Spaces:
Sleeping
Sleeping
vineetshukla.work@gmail.com
fix: task routing by name, remove out-of-range rewards, add grader field to tasks
b64950c | """ | |
| CodeSensei — CodeDebug OpenEnv Environment. | |
| Implements the core Environment with reset(), step(), and state | |
| following the OpenEnv 3-method pattern. Manages episodes, tracks | |
| attempts, computes multi-signal rewards, and detects repeated fixes. | |
| """ | |
| from __future__ import annotations | |
| import hashlib | |
| import json | |
| import os | |
| import random | |
| import uuid | |
| from typing import Dict, List, Optional, Tuple | |
| from env.models import CodeDebugAction, CodeDebugObservation, CodeDebugState, TestResult | |
| from env.server.test_runner import run_tests | |
| from env.server.sandbox import check_syntax | |
# On-disk bug dataset location: the "data" directory inside the parent of
# this module's own directory (two dirname() hops up from __file__).
_DATA_DIR = os.path.join(
    os.path.dirname(os.path.dirname(__file__)),
    "data",
)

# Module-wide list of bug entries; filled in by _load_dataset().
_BUG_DATASET: List[dict] = []
def _load_dataset():
    """Populate the module-level ``_BUG_DATASET`` list.

    Reads ``bug_dataset.json`` from ``_DATA_DIR`` when that file exists.
    When the file is missing -- or parses to an empty/falsy value -- the
    small built-in dataset from ``_get_builtin_dataset()`` is used instead,
    so callers that sample from ``_BUG_DATASET`` always have at least one
    entry. A malformed JSON file still raises (``json.JSONDecodeError``)
    rather than being silently ignored.
    """
    global _BUG_DATASET
    dataset_path = os.path.join(_DATA_DIR, "bug_dataset.json")
    loaded: List[dict] = []
    if os.path.exists(dataset_path):
        with open(dataset_path, "r", encoding="utf-8") as f:
            loaded = json.load(f)
    # An empty dataset would make random.choice() in reset() raise
    # IndexError, so treat it exactly like a missing file.
    _BUG_DATASET = loaded if loaded else _get_builtin_dataset()
def _get_builtin_dataset() -> List[dict]:
    """Return the built-in minimal dataset used when bug_dataset.json is missing.

    Each entry has the keys ``function_name``, ``buggy_code``,
    ``correct_code``, ``bug_description`` and ``tests``; ``tests`` is a list
    of ``{"name": ..., "code": ...}`` dicts whose ``code`` is a single
    ``assert`` statement. Invariant: every test must pass against
    ``correct_code``, otherwise the episode can never be solved. A fresh
    list is built on every call, so callers may mutate the result freely.
    """
    return [
        {
            "function_name": "add_numbers",
            "buggy_code": "def add_numbers(a, b):\n    return a - b",
            "correct_code": "def add_numbers(a, b):\n    return a + b",
            "bug_description": "Uses subtraction instead of addition",
            "tests": [
                {"name": "basic addition", "code": "assert add_numbers(2, 3) == 5"},
                {"name": "zero addition", "code": "assert add_numbers(0, 0) == 0"},
                {"name": "negative addition", "code": "assert add_numbers(-1, 1) == 0"},
            ],
        },
        {
            "function_name": "find_max",
            "buggy_code": (
                "def find_max(lst):\n"
                "    if not lst:\n"
                "        return None\n"
                "    result = lst[0]\n"
                "    for x in lst:\n"
                "        if x < result:\n"
                "            result = x\n"
                "    return result"
            ),
            "correct_code": (
                "def find_max(lst):\n"
                "    if not lst:\n"
                "        return None\n"
                "    result = lst[0]\n"
                "    for x in lst:\n"
                "        if x > result:\n"
                "            result = x\n"
                "    return result"
            ),
            "bug_description": "Uses < instead of > (finds minimum instead of maximum)",
            "tests": [
                {"name": "basic max", "code": "assert find_max([1, 3, 2]) == 3"},
                {"name": "single element", "code": "assert find_max([5]) == 5"},
                {"name": "negative numbers", "code": "assert find_max([-1, -5, -2]) == -1"},
                {"name": "empty list", "code": "assert find_max([]) is None"},
            ],
        },
        {
            "function_name": "reverse_string",
            "buggy_code": "def reverse_string(s):\n    return s[1:]",
            "correct_code": "def reverse_string(s):\n    return s[::-1]",
            "bug_description": "Slices from index 1 instead of reversing",
            "tests": [
                {"name": "basic reverse", "code": 'assert reverse_string("hello") == "olleh"'},
                {"name": "empty string", "code": 'assert reverse_string("") == ""'},
                {"name": "single char", "code": 'assert reverse_string("a") == "a"'},
                {"name": "palindrome", "code": 'assert reverse_string("racecar") == "racecar"'},
            ],
        },
        {
            "function_name": "fibonacci",
            "buggy_code": (
                "def fibonacci(n):\n"
                "    if n <= 0:\n"
                "        return 0\n"
                "    if n == 1:\n"
                "        return 1\n"
                "    return fibonacci(n - 1) + fibonacci(n - 3)"
            ),
            "correct_code": (
                "def fibonacci(n):\n"
                "    if n <= 0:\n"
                "        return 0\n"
                "    if n == 1:\n"
                "        return 1\n"
                "    return fibonacci(n - 1) + fibonacci(n - 2)"
            ),
            "bug_description": "Recursive call uses n-3 instead of n-2",
            "tests": [
                {"name": "fib(0)", "code": "assert fibonacci(0) == 0"},
                {"name": "fib(1)", "code": "assert fibonacci(1) == 1"},
                {"name": "fib(5)", "code": "assert fibonacci(5) == 5"},
                {"name": "fib(10)", "code": "assert fibonacci(10) == 55"},
            ],
        },
        {
            "function_name": "count_vowels",
            "buggy_code": (
                "def count_vowels(s):\n"
                "    count = 0\n"
                "    for c in s:\n"
                "        if c in 'aeiou':\n"
                "            count += 1\n"
                "    return count"
            ),
            "correct_code": (
                "def count_vowels(s):\n"
                "    count = 0\n"
                "    for c in s.lower():\n"
                "        if c in 'aeiou':\n"
                "            count += 1\n"
                "    return count"
            ),
            "bug_description": "Does not handle uppercase vowels",
            "tests": [
                {"name": "lowercase", "code": "assert count_vowels('hello') == 2"},
                {"name": "uppercase", "code": "assert count_vowels('HELLO') == 2"},
                {"name": "mixed case", "code": "assert count_vowels('HeLLo') == 2"},
                {"name": "no vowels", "code": "assert count_vowels('xyz') == 0"},
            ],
        },
        {
            "function_name": "is_palindrome",
            "buggy_code": (
                "def is_palindrome(s):\n"
                "    s = s.lower()\n"
                "    return s == s[::-1]"
            ),
            "correct_code": (
                "def is_palindrome(s):\n"
                "    s = ''.join(c for c in s.lower() if c.isalnum())\n"
                "    return s == s[::-1]"
            ),
            "bug_description": "Does not strip non-alphanumeric characters before checking",
            "tests": [
                {"name": "basic palindrome", "code": "assert is_palindrome('racecar') == True"},
                # 'race car' becomes 'racecar' once non-alphanumerics are
                # stripped, so the corrected function must return True here
                # (expecting False made the reference solution fail).
                {"name": "with spaces", "code": "assert is_palindrome('race car') == True"},
                {"name": "with punctuation", "code": "assert is_palindrome('A man, a plan, a canal: Panama') == True"},
                {"name": "not palindrome", "code": "assert is_palindrome('hello') == False"},
            ],
        },
        {
            "function_name": "flatten_list",
            "buggy_code": (
                "def flatten_list(lst):\n"
                "    result = []\n"
                "    for item in lst:\n"
                "        if isinstance(item, list):\n"
                "            result.append(item)\n"
                "        else:\n"
                "            result.append(item)\n"
                "    return result"
            ),
            "correct_code": (
                "def flatten_list(lst):\n"
                "    result = []\n"
                "    for item in lst:\n"
                "        if isinstance(item, list):\n"
                "            result.extend(flatten_list(item))\n"
                "        else:\n"
                "            result.append(item)\n"
                "    return result"
            ),
            "bug_description": "Appends nested lists instead of recursively flattening them",
            "tests": [
                {"name": "nested", "code": "assert flatten_list([1, [2, 3], [4, [5]]]) == [1, 2, 3, 4, 5]"},
                {"name": "already flat", "code": "assert flatten_list([1, 2, 3]) == [1, 2, 3]"},
                {"name": "empty", "code": "assert flatten_list([]) == []"},
                {"name": "deep nesting", "code": "assert flatten_list([[[[1]]]]) == [1]"},
            ],
        },
        {
            "function_name": "binary_search",
            "buggy_code": (
                "def binary_search(arr, target):\n"
                "    left, right = 0, len(arr) - 1\n"
                "    while left < right:\n"
                "        mid = (left + right) // 2\n"
                "        if arr[mid] == target:\n"
                "            return mid\n"
                "        elif arr[mid] < target:\n"
                "            left = mid\n"
                "        else:\n"
                "            right = mid - 1\n"
                "    return -1"
            ),
            "correct_code": (
                "def binary_search(arr, target):\n"
                "    left, right = 0, len(arr) - 1\n"
                "    while left <= right:\n"
                "        mid = (left + right) // 2\n"
                "        if arr[mid] == target:\n"
                "            return mid\n"
                "        elif arr[mid] < target:\n"
                "            left = mid + 1\n"
                "        else:\n"
                "            right = mid - 1\n"
                "    return -1"
            ),
            "bug_description": "Uses < instead of <= in while condition, and left=mid instead of left=mid+1 (infinite loop / misses elements)",
            "tests": [
                {"name": "found middle", "code": "assert binary_search([1,2,3,4,5], 3) == 2"},
                {"name": "found first", "code": "assert binary_search([1,2,3,4,5], 1) == 0"},
                {"name": "found last", "code": "assert binary_search([1,2,3,4,5], 5) == 4"},
                {"name": "not found", "code": "assert binary_search([1,2,3,4,5], 6) == -1"},
            ],
        },
        {
            "function_name": "merge_sorted",
            "buggy_code": (
                "def merge_sorted(a, b):\n"
                "    result = []\n"
                "    i = j = 0\n"
                "    while i < len(a) and j < len(b):\n"
                "        if a[i] <= b[j]:\n"
                "            result.append(a[i])\n"
                "            i += 1\n"
                "        else:\n"
                "            result.append(b[j])\n"
                "            j += 1\n"
                "    return result"
            ),
            "correct_code": (
                "def merge_sorted(a, b):\n"
                "    result = []\n"
                "    i = j = 0\n"
                "    while i < len(a) and j < len(b):\n"
                "        if a[i] <= b[j]:\n"
                "            result.append(a[i])\n"
                "            i += 1\n"
                "        else:\n"
                "            result.append(b[j])\n"
                "            j += 1\n"
                "    result.extend(a[i:])\n"
                "    result.extend(b[j:])\n"
                "    return result"
            ),
            "bug_description": "Missing the remaining elements after the while loop ends",
            "tests": [
                {"name": "basic merge", "code": "assert merge_sorted([1,3,5], [2,4,6]) == [1,2,3,4,5,6]"},
                {"name": "one empty", "code": "assert merge_sorted([], [1,2,3]) == [1,2,3]"},
                {"name": "both empty", "code": "assert merge_sorted([], []) == []"},
                {"name": "unequal length", "code": "assert merge_sorted([1], [2,3,4]) == [1,2,3,4]"},
            ],
        },
        {
            "function_name": "remove_duplicates",
            "buggy_code": (
                "def remove_duplicates(lst):\n"
                "    seen = set()\n"
                "    result = []\n"
                "    for item in lst:\n"
                "        if item in seen:\n"
                "            result.append(item)\n"
                "        seen.add(item)\n"
                "    return result"
            ),
            "correct_code": (
                "def remove_duplicates(lst):\n"
                "    seen = set()\n"
                "    result = []\n"
                "    for item in lst:\n"
                "        if item not in seen:\n"
                "            result.append(item)\n"
                "        seen.add(item)\n"
                "    return result"
            ),
            "bug_description": "Condition is inverted: keeps duplicates and removes unique items",
            "tests": [
                {"name": "basic dedup", "code": "assert remove_duplicates([1,2,2,3,3,3]) == [1,2,3]"},
                {"name": "no duplicates", "code": "assert remove_duplicates([1,2,3]) == [1,2,3]"},
                {"name": "all same", "code": "assert remove_duplicates([5,5,5]) == [5]"},
                {"name": "empty", "code": "assert remove_duplicates([]) == []"},
            ],
        },
    ]
class CodeDebugEnvironment:
    """OpenEnv-compatible environment for code debugging RL.

    The agent receives a buggy Python function and must propose fixes.
    Each step runs the proposed fix against test cases and returns
    multi-signal reward feedback.

    Episodes are keyed by ``session_id`` and kept in ``self._sessions``
    until they finish (solved, or out of attempts), at which point they are
    removed; stepping a removed/unknown session yields an error observation.
    Every reward returned lies in ``[MIN_REWARD, MAX_REWARD]``, i.e. strictly
    inside (0, 1), as required by the Phase 2 validation.
    """

    # Attempts allowed per episode before it is forced to end.
    MAX_ATTEMPTS = 6
    # Reward bounds: Phase 2 validation forbids 0.0 and values outside (0, 1).
    MIN_REWARD = 0.01
    MAX_REWARD = 0.99
    # Prefix of task names from openenv.yaml, e.g. "debug-add_numbers".
    _TASK_PREFIX = "debug-"

    def __init__(self):
        _load_dataset()
        self._sessions: Dict[str, CodeDebugState] = {}

    @classmethod
    def _task_to_function_name(cls, task: str) -> str:
        """Map a task name to a dataset ``function_name``.

        ``"debug-add_numbers"`` -> ``"add_numbers"``. Only a *leading*
        ``"debug-"`` is stripped; names without it are returned unchanged.
        """
        prefix = cls._TASK_PREFIX
        return task[len(prefix):] if task.startswith(prefix) else task

    def reset(self, session_id: str = "", task: Optional[str] = None) -> CodeDebugObservation:
        """Start a new episode: sample a buggy function.

        Args:
            session_id: WebSocket session ID. Auto-generated if empty. An
                existing episode under the same ID is replaced.
            task: Optional task name from openenv.yaml (e.g. "debug-add_numbers").
                If it names a known function, that bug is used; if it is
                missing or unknown, a bug is picked at random.

        Returns:
            Initial observation with the buggy code, the buggy code's test
            results, and a markdown bug report in ``feedback``.
        """
        if not session_id:
            session_id = str(uuid.uuid4())

        # Select bug by task name, falling back to a random choice.
        bug = None
        if task:
            fn_name = self._task_to_function_name(task)
            bug = next((b for b in _BUG_DATASET if b["function_name"] == fn_name), None)
        if bug is None:
            bug = random.choice(_BUG_DATASET)

        state = CodeDebugState(
            episode_id=str(uuid.uuid4()),
            session_id=session_id,
            attempt=0,
            max_attempts=self.MAX_ATTEMPTS,
            original_bug=bug["buggy_code"],
            current_code=bug["buggy_code"],
            bug_description=bug["bug_description"],
            function_name=bug["function_name"],
            tests_passed_history=[],
            fix_hashes=[],
            solved=False,
        )
        # Keep the full bug record (including tests) on the state so step()
        # can re-run the tests for this session.
        state._bug_data = bug  # type: ignore[attr-defined]
        self._sessions[session_id] = state

        # Run tests on the buggy code to show the initial failures.
        test_results, passed, total, error_output = run_tests(
            bug["buggy_code"], bug["tests"]
        )
        feedback = self._build_feedback(bug, test_results, passed, total, is_initial=True)

        return CodeDebugObservation(
            buggy_code=bug["buggy_code"],
            current_code=bug["buggy_code"],
            error_output=error_output,
            test_results=test_results,
            tests_passed=passed,
            tests_total=total,
            reward=self.MIN_REWARD,  # Non-zero initial reward (0.0 is forbidden by Phase 2)
            done=False,
            attempt=0,
            max_attempts=self.MAX_ATTEMPTS,
            feedback=feedback,
        )

    def step(self, action: CodeDebugAction) -> CodeDebugObservation:
        """Apply the agent's proposed fix and evaluate.

        Reward signals (summed, then clamped to [MIN_REWARD, MAX_REWARD]):
        repeated fix -0.5; all tests pass +2.0; more tests than any previous
        attempt +0.5; no improvement with some passing -0.3; first attempt
        with some passing +0.5; nothing passing with error output -0.5.
        A syntax error short-circuits to the minimum reward.

        Args:
            action: CodeDebugAction with the session ID and proposed fix.

        Returns:
            Observation with test results, reward, and feedback. The session
            is removed once the observation is ``done``.
        """
        session_id = action.session_id
        state = self._sessions.get(session_id)
        if state is None:
            return CodeDebugObservation(
                buggy_code="",
                current_code="",
                error_output="Invalid session_id. Call reset() first.",
                # Kept inside (0, 1) like every other reward; done=True and
                # the error text are what signal the failure.
                reward=self.MIN_REWARD,
                done=True,
                feedback="Error: Invalid session. Please call reset() first.",
            )

        bug = state._bug_data  # type: ignore[attr-defined]
        state.attempt += 1
        proposed_fix = action.proposed_fix.strip()
        num_tests = len(bug["tests"])

        # 1. Syntax check: code that does not parse is never executed.
        is_valid_syntax, syntax_error = check_syntax(proposed_fix)
        if not is_valid_syntax:
            state.tests_passed_history.append(0)
            state.current_code = proposed_fix
            done = state.attempt >= state.max_attempts
            feedback = self._build_fix_feedback(
                bug, [], 0, num_tests,
                syntax_error=syntax_error, attempt=state.attempt,
                max_attempts=state.max_attempts,
            )
            if done:
                self._sessions.pop(session_id, None)
            return CodeDebugObservation(
                buggy_code=bug["buggy_code"],
                current_code=proposed_fix,
                error_output=syntax_error,
                test_results=[],
                tests_passed=0,
                tests_total=num_tests,
                reward=self.MIN_REWARD,  # Syntax error gives the minimum reward
                done=done,
                attempt=state.attempt,
                max_attempts=state.max_attempts,
                feedback=feedback,
            )

        total_reward = 0.0

        # 2. Repetition check: identical (stripped) submissions are penalized.
        fix_hash = hashlib.sha256(proposed_fix.encode()).hexdigest()
        is_repeated = fix_hash in state.fix_hashes
        state.fix_hashes.append(fix_hash)
        if is_repeated:
            total_reward -= 0.5

        # 3. Run tests
        test_results, passed, total, error_output = run_tests(
            proposed_fix, bug["tests"]
        )

        # 4. Correctness reward
        if passed == total:
            total_reward += 2.0
            state.solved = True
        elif state.tests_passed_history:
            prev_best = max(state.tests_passed_history)
            if passed > prev_best:
                total_reward += 0.5  # Progress
            elif passed > 0:
                total_reward -= 0.3  # Stagnation
        elif passed > 0:
            total_reward += 0.5  # First partial success
        # NOTE(review): the progress baseline is the agent's own history, not
        # the buggy code's pass count from reset(), so resubmitting the
        # unmodified buggy code on attempt 1 earns +0.5 whenever it passes
        # any test. Consider seeding the baseline from reset().

        # 5. Runtime error penalty (nothing passed and the run produced errors;
        #    syntax errors already returned above).
        if passed == 0 and error_output:
            total_reward -= 0.5

        state.tests_passed_history.append(passed)
        state.current_code = proposed_fix

        # Strictly bound the reward to (0, 1) as required by Phase 2 Deep Validation.
        total_reward = min(max(total_reward, self.MIN_REWARD), self.MAX_REWARD)

        done = state.solved or state.attempt >= state.max_attempts
        feedback = self._build_fix_feedback(
            bug, test_results, passed, total,
            attempt=state.attempt, max_attempts=state.max_attempts,
            is_repeated=is_repeated,
        )
        if done:
            self._sessions.pop(session_id, None)

        return CodeDebugObservation(
            buggy_code=bug["buggy_code"],
            current_code=proposed_fix,
            error_output=error_output,
            test_results=test_results,
            tests_passed=passed,
            tests_total=total,
            reward=total_reward,
            done=done,
            attempt=state.attempt,
            max_attempts=state.max_attempts,
            feedback=feedback,
        )

    def get_state(self, session_id: str) -> Optional[CodeDebugState]:
        """Get current state for a session.

        Args:
            session_id: WebSocket session ID.

        Returns:
            Current state, or None if the session is unknown or has finished.
        """
        return self._sessions.get(session_id)

    # --- Feedback builders ---

    def _build_feedback(
        self, bug: dict, test_results: List[TestResult],
        passed: int, total: int, is_initial: bool = False,
    ) -> str:
        """Render the markdown prompt for the initial observation.

        With ``is_initial=True`` the text contains a bug-report header, the
        pass count, the buggy code, and each failing test with its error
        message. It always ends with a request for the corrected function.
        """
        lines: List[str] = []
        if is_initial:
            name = bug["function_name"]
            lines.append(f"## Bug Report: `{name}`")
            lines.append(f"The function `{name}` has a bug.")
            lines.append(f"**Current test results:** {passed}/{total} tests passing")
            lines.append("")
            lines.append("### Buggy Code:")
            lines.append(f"```python\n{bug['buggy_code']}\n```")
            lines.append("")
            lines.append("### Failing Tests:")
            lines.extend(
                f"- ❌ {tr.test_name}: {tr.error_message}"
                for tr in test_results
                if not tr.passed
            )
            lines.append("")
        lines.append("Please provide the corrected function.")
        return "\n".join(lines)

    def _build_fix_feedback(
        self, bug: dict, test_results: List[TestResult],
        passed: int, total: int,
        syntax_error: str = "", attempt: int = 0,
        max_attempts: int = MAX_ATTEMPTS, is_repeated: bool = False,
    ) -> str:
        """Render the markdown feedback for one fix attempt.

        A non-empty ``syntax_error`` produces a short syntax-error message
        only. Otherwise the text reports the pass count, warns on a repeated
        fix, lists per-test results unless everything passed, and ends with
        the number of attempts remaining (omitted when none remain).
        ``bug`` is accepted for interface symmetry and is not read.
        """
        lines = [f"## Attempt {attempt}/{max_attempts}"]
        if syntax_error:
            lines.append(f"❌ **Syntax Error:** {syntax_error}")
            lines.append("Please fix the syntax and try again.")
            return "\n".join(lines)

        if is_repeated:
            lines.append("⚠️ **Warning:** You submitted the same fix as before. Try a different approach.")
        lines.append(f"**Tests:** {passed}/{total} passing")

        if passed == total:
            lines.append("✅ **All tests pass! Bug fixed successfully!**")
        else:
            lines.append("")
            lines.append("### Test Results:")
            for tr in test_results:
                status = "✅" if tr.passed else "❌"
                entry = f"- {status} {tr.test_name}"
                if not tr.passed and tr.error_message:
                    entry += f": {tr.error_message}"
                lines.append(entry)

        remaining = max_attempts - attempt
        if remaining > 0:
            noun = "attempt" if remaining == 1 else "attempts"
            lines.append(f"\n{remaining} {noun} remaining.")
        return "\n".join(lines)