codesensei-env / env /server /environment.py
vineetshukla.work@gmail.com
fix: task routing by name, remove out-of-range rewards, add grader field to tasks
b64950c
"""
CodeSensei — CodeDebug OpenEnv Environment.
Implements the core Environment with reset(), step(), and state
following the OpenEnv 3-method pattern. Manages episodes, tracks
attempts, computes multi-signal rewards, and detects repeated fixes.
"""
from __future__ import annotations
import hashlib
import json
import os
import random
import uuid
from typing import Dict, List, Optional, Tuple
from env.models import CodeDebugAction, CodeDebugObservation, CodeDebugState, TestResult
from env.server.test_runner import run_tests
from env.server.sandbox import check_syntax
# Directory expected to hold ``bug_dataset.json``: the ``data`` folder that is
# a sibling of the directory containing this module.
_DATA_DIR = os.path.join(
    os.path.dirname(os.path.dirname(__file__)),
    "data",
)

# Bug records currently served by the environment; filled by _load_dataset().
_BUG_DATASET: List[dict] = []
def _load_dataset():
    """Populate the module-level ``_BUG_DATASET``.

    Reads ``bug_dataset.json`` from ``_DATA_DIR`` when the file exists.
    The built-in dataset is used instead when the file is missing, or when
    its content is not a non-empty JSON list: an empty or non-list dataset
    would make every later ``random.choice(_BUG_DATASET)`` call fail.

    Raises:
        json.JSONDecodeError: If the file exists but is not valid JSON.
        OSError: If the file exists but cannot be opened or read.
    """
    global _BUG_DATASET
    dataset_path = os.path.join(_DATA_DIR, "bug_dataset.json")

    loaded = None
    if os.path.exists(dataset_path):
        with open(dataset_path, "r", encoding="utf-8") as f:
            loaded = json.load(f)

    if isinstance(loaded, list) and loaded:
        _BUG_DATASET = loaded
    else:
        # Fallback: built-in minimal dataset (file missing, empty, or not a list)
        _BUG_DATASET = _get_builtin_dataset()
def _get_builtin_dataset() -> List[dict]:
    """Built-in minimal dataset used when bug_dataset.json is missing.

    Returns:
        A fresh list of bug records. Each record is a dict with the keys
        ``function_name``, ``buggy_code`` (source containing the defect),
        ``correct_code`` (reference fix), ``bug_description`` and ``tests``
        (a list of ``{"name", "code"}`` dicts whose ``code`` is a single
        ``assert`` statement calling the function).

    Every ``correct_code`` snippet satisfies all of its own tests.
    NOTE: the buggy ``binary_search`` can loop forever on some of its tests
    (as its description says), so it must only be executed under a timeout.
    """
    return [
        {
            "function_name": "add_numbers",
            "buggy_code": (
                "def add_numbers(a, b):\n"
                "    return a - b"
            ),
            "correct_code": (
                "def add_numbers(a, b):\n"
                "    return a + b"
            ),
            "bug_description": "Uses subtraction instead of addition",
            "tests": [
                {"name": "basic addition", "code": "assert add_numbers(2, 3) == 5"},
                {"name": "zero addition", "code": "assert add_numbers(0, 0) == 0"},
                {"name": "negative addition", "code": "assert add_numbers(-1, 1) == 0"},
            ],
        },
        {
            "function_name": "find_max",
            "buggy_code": (
                "def find_max(lst):\n"
                "    if not lst:\n"
                "        return None\n"
                "    result = lst[0]\n"
                "    for x in lst:\n"
                "        if x < result:\n"
                "            result = x\n"
                "    return result"
            ),
            "correct_code": (
                "def find_max(lst):\n"
                "    if not lst:\n"
                "        return None\n"
                "    result = lst[0]\n"
                "    for x in lst:\n"
                "        if x > result:\n"
                "            result = x\n"
                "    return result"
            ),
            "bug_description": "Uses < instead of > (finds minimum instead of maximum)",
            "tests": [
                {"name": "basic max", "code": "assert find_max([1, 3, 2]) == 3"},
                {"name": "single element", "code": "assert find_max([5]) == 5"},
                {"name": "negative numbers", "code": "assert find_max([-1, -5, -2]) == -1"},
                {"name": "empty list", "code": "assert find_max([]) is None"},
            ],
        },
        {
            "function_name": "reverse_string",
            "buggy_code": (
                "def reverse_string(s):\n"
                "    return s[1:]"
            ),
            "correct_code": (
                "def reverse_string(s):\n"
                "    return s[::-1]"
            ),
            "bug_description": "Slices from index 1 instead of reversing",
            "tests": [
                {"name": "basic reverse", "code": 'assert reverse_string("hello") == "olleh"'},
                {"name": "empty string", "code": 'assert reverse_string("") == ""'},
                {"name": "single char", "code": 'assert reverse_string("a") == "a"'},
                {"name": "palindrome", "code": 'assert reverse_string("racecar") == "racecar"'},
            ],
        },
        {
            "function_name": "fibonacci",
            "buggy_code": (
                "def fibonacci(n):\n"
                "    if n <= 0:\n"
                "        return 0\n"
                "    if n == 1:\n"
                "        return 1\n"
                "    return fibonacci(n - 1) + fibonacci(n - 3)"
            ),
            "correct_code": (
                "def fibonacci(n):\n"
                "    if n <= 0:\n"
                "        return 0\n"
                "    if n == 1:\n"
                "        return 1\n"
                "    return fibonacci(n - 1) + fibonacci(n - 2)"
            ),
            "bug_description": "Recursive call uses n-3 instead of n-2",
            "tests": [
                {"name": "fib(0)", "code": "assert fibonacci(0) == 0"},
                {"name": "fib(1)", "code": "assert fibonacci(1) == 1"},
                {"name": "fib(5)", "code": "assert fibonacci(5) == 5"},
                {"name": "fib(10)", "code": "assert fibonacci(10) == 55"},
            ],
        },
        {
            "function_name": "count_vowels",
            "buggy_code": (
                "def count_vowels(s):\n"
                "    count = 0\n"
                "    for c in s:\n"
                "        if c in 'aeiou':\n"
                "            count += 1\n"
                "    return count"
            ),
            "correct_code": (
                "def count_vowels(s):\n"
                "    count = 0\n"
                "    for c in s.lower():\n"
                "        if c in 'aeiou':\n"
                "            count += 1\n"
                "    return count"
            ),
            "bug_description": "Does not handle uppercase vowels",
            "tests": [
                {"name": "lowercase", "code": "assert count_vowels('hello') == 2"},
                {"name": "uppercase", "code": "assert count_vowels('HELLO') == 2"},
                {"name": "mixed case", "code": "assert count_vowels('HeLLo') == 2"},
                {"name": "no vowels", "code": "assert count_vowels('xyz') == 0"},
            ],
        },
        {
            "function_name": "is_palindrome",
            "buggy_code": (
                "def is_palindrome(s):\n"
                "    s = s.lower()\n"
                "    return s == s[::-1]"
            ),
            "correct_code": (
                "def is_palindrome(s):\n"
                "    s = ''.join(c for c in s.lower() if c.isalnum())\n"
                "    return s == s[::-1]"
            ),
            "bug_description": "Does not strip non-alphanumeric characters before checking",
            "tests": [
                {"name": "basic palindrome", "code": "assert is_palindrome('racecar') == True"},
                # The reference fix drops spaces, so 'race car' IS a palindrome.
                {"name": "with spaces", "code": "assert is_palindrome('race car') == True"},
                {"name": "with punctuation", "code": "assert is_palindrome('A man, a plan, a canal: Panama') == True"},
                {"name": "not palindrome", "code": "assert is_palindrome('hello') == False"},
            ],
        },
        {
            "function_name": "flatten_list",
            "buggy_code": (
                "def flatten_list(lst):\n"
                "    result = []\n"
                "    for item in lst:\n"
                "        if isinstance(item, list):\n"
                "            result.append(item)\n"
                "        else:\n"
                "            result.append(item)\n"
                "    return result"
            ),
            "correct_code": (
                "def flatten_list(lst):\n"
                "    result = []\n"
                "    for item in lst:\n"
                "        if isinstance(item, list):\n"
                "            result.extend(flatten_list(item))\n"
                "        else:\n"
                "            result.append(item)\n"
                "    return result"
            ),
            "bug_description": "Appends nested lists instead of recursively flattening them",
            "tests": [
                {"name": "nested", "code": "assert flatten_list([1, [2, 3], [4, [5]]]) == [1, 2, 3, 4, 5]"},
                {"name": "already flat", "code": "assert flatten_list([1, 2, 3]) == [1, 2, 3]"},
                {"name": "empty", "code": "assert flatten_list([]) == []"},
                {"name": "deep nesting", "code": "assert flatten_list([[[[1]]]]) == [1]"},
            ],
        },
        {
            "function_name": "binary_search",
            "buggy_code": (
                "def binary_search(arr, target):\n"
                "    left, right = 0, len(arr) - 1\n"
                "    while left < right:\n"
                "        mid = (left + right) // 2\n"
                "        if arr[mid] == target:\n"
                "            return mid\n"
                "        elif arr[mid] < target:\n"
                "            left = mid\n"
                "        else:\n"
                "            right = mid - 1\n"
                "    return -1"
            ),
            "correct_code": (
                "def binary_search(arr, target):\n"
                "    left, right = 0, len(arr) - 1\n"
                "    while left <= right:\n"
                "        mid = (left + right) // 2\n"
                "        if arr[mid] == target:\n"
                "            return mid\n"
                "        elif arr[mid] < target:\n"
                "            left = mid + 1\n"
                "        else:\n"
                "            right = mid - 1\n"
                "    return -1"
            ),
            "bug_description": "Uses < instead of <= in while condition, and left=mid instead of left=mid+1 (infinite loop / misses elements)",
            "tests": [
                {"name": "found middle", "code": "assert binary_search([1,2,3,4,5], 3) == 2"},
                {"name": "found first", "code": "assert binary_search([1,2,3,4,5], 1) == 0"},
                {"name": "found last", "code": "assert binary_search([1,2,3,4,5], 5) == 4"},
                {"name": "not found", "code": "assert binary_search([1,2,3,4,5], 6) == -1"},
            ],
        },
        {
            "function_name": "merge_sorted",
            "buggy_code": (
                "def merge_sorted(a, b):\n"
                "    result = []\n"
                "    i = j = 0\n"
                "    while i < len(a) and j < len(b):\n"
                "        if a[i] <= b[j]:\n"
                "            result.append(a[i])\n"
                "            i += 1\n"
                "        else:\n"
                "            result.append(b[j])\n"
                "            j += 1\n"
                "    return result"
            ),
            "correct_code": (
                "def merge_sorted(a, b):\n"
                "    result = []\n"
                "    i = j = 0\n"
                "    while i < len(a) and j < len(b):\n"
                "        if a[i] <= b[j]:\n"
                "            result.append(a[i])\n"
                "            i += 1\n"
                "        else:\n"
                "            result.append(b[j])\n"
                "            j += 1\n"
                "    result.extend(a[i:])\n"
                "    result.extend(b[j:])\n"
                "    return result"
            ),
            "bug_description": "Missing the remaining elements after the while loop ends",
            "tests": [
                {"name": "basic merge", "code": "assert merge_sorted([1,3,5], [2,4,6]) == [1,2,3,4,5,6]"},
                {"name": "one empty", "code": "assert merge_sorted([], [1,2,3]) == [1,2,3]"},
                {"name": "both empty", "code": "assert merge_sorted([], []) == []"},
                {"name": "unequal length", "code": "assert merge_sorted([1], [2,3,4]) == [1,2,3,4]"},
            ],
        },
        {
            "function_name": "remove_duplicates",
            "buggy_code": (
                "def remove_duplicates(lst):\n"
                "    seen = set()\n"
                "    result = []\n"
                "    for item in lst:\n"
                "        if item in seen:\n"
                "            result.append(item)\n"
                "        seen.add(item)\n"
                "    return result"
            ),
            "correct_code": (
                "def remove_duplicates(lst):\n"
                "    seen = set()\n"
                "    result = []\n"
                "    for item in lst:\n"
                "        if item not in seen:\n"
                "            result.append(item)\n"
                "        seen.add(item)\n"
                "    return result"
            ),
            "bug_description": "Condition is inverted: keeps duplicates and removes unique items",
            "tests": [
                {"name": "basic dedup", "code": "assert remove_duplicates([1,2,2,3,3,3]) == [1,2,3]"},
                {"name": "no duplicates", "code": "assert remove_duplicates([1,2,3]) == [1,2,3]"},
                {"name": "all same", "code": "assert remove_duplicates([5,5,5]) == [5]"},
                {"name": "empty", "code": "assert remove_duplicates([]) == []"},
            ],
        },
    ]
class CodeDebugEnvironment:
    """OpenEnv-compatible environment for code debugging RL.

    The agent receives a buggy Python function and must propose fixes.
    Each step runs the proposed fix against test cases and returns
    multi-signal reward feedback.

    Every reward handed back to the caller lies strictly inside (0, 1):
    raw reward signals are summed and then clamped to
    ``[_MIN_REWARD, _MAX_REWARD]``.
    """

    # Bounds applied to every reported reward (0.0 and 1.0 are not allowed).
    _MIN_REWARD = 0.01
    _MAX_REWARD = 0.99
    # Number of fix attempts allowed per episode.
    _MAX_ATTEMPTS = 6
    # Prefix that task names carry in front of the function name.
    _TASK_PREFIX = "debug-"

    def __init__(self):
        """Load the bug dataset and start with no active sessions."""
        _load_dataset()
        self._sessions: Dict[str, CodeDebugState] = {}

    @classmethod
    def _task_to_function_name(cls, task: str) -> str:
        """Strip a leading ``debug-`` from *task*, leaving other text alone."""
        if task.startswith(cls._TASK_PREFIX):
            return task[len(cls._TASK_PREFIX):]
        return task

    @classmethod
    def _clamp_reward(cls, reward: float) -> float:
        """Bound *reward* to ``[_MIN_REWARD, _MAX_REWARD]``."""
        return min(max(reward, cls._MIN_REWARD), cls._MAX_REWARD)

    def reset(self, session_id: str = "", task: Optional[str] = None) -> CodeDebugObservation:
        """Start a new episode: sample a buggy function.

        Args:
            session_id: WebSocket session ID. Auto-generated if empty.
                An existing session with the same ID is replaced.
            task: Optional task name from openenv.yaml (e.g. "debug-add_numbers").
                If provided, selects the matching bug. Otherwise (or when no
                bug matches) one is picked randomly.

        Returns:
            Initial observation with the buggy code and test info.
        """
        if not session_id:
            session_id = str(uuid.uuid4())

        # Select bug by task name or randomly
        bug = None
        if task:
            # "debug-add_numbers" -> "add_numbers"; only a leading prefix is removed.
            fn_name = self._task_to_function_name(task)
            bug = next((b for b in _BUG_DATASET if b["function_name"] == fn_name), None)
        if bug is None:
            bug = random.choice(_BUG_DATASET)

        # Create state
        state = CodeDebugState(
            episode_id=str(uuid.uuid4()),
            session_id=session_id,
            attempt=0,
            max_attempts=self._MAX_ATTEMPTS,
            original_bug=bug["buggy_code"],
            current_code=bug["buggy_code"],
            bug_description=bug["bug_description"],
            function_name=bug["function_name"],
            tests_passed_history=[],
            fix_hashes=[],
            solved=False,
        )
        # Store the bug data in state for test access
        state._bug_data = bug  # type: ignore[attr-defined]
        self._sessions[session_id] = state

        # Run tests on the buggy code to show initial failure
        test_results, passed, total, error_output = run_tests(
            bug["buggy_code"], bug["tests"]
        )
        feedback = self._build_feedback(bug, test_results, passed, total, is_initial=True)

        return CodeDebugObservation(
            buggy_code=bug["buggy_code"],
            current_code=bug["buggy_code"],
            error_output=error_output,
            test_results=test_results,
            tests_passed=passed,
            tests_total=total,
            reward=self._MIN_REWARD,  # Non-zero initial reward (0.0 is forbidden by Phase 2)
            done=False,
            attempt=0,
            max_attempts=state.max_attempts,
            feedback=feedback,
        )

    def step(self, action: CodeDebugAction) -> CodeDebugObservation:
        """Apply the agent's proposed fix and evaluate.

        Raw reward signals, summed before clamping:
          * -0.5 if the exact same fix text was already submitted;
          * +2.0 if all tests pass (the episode is then solved);
          * +0.5 if more tests pass than in any earlier attempt, or if this
            is the first evaluated attempt and at least one test passes;
          * -0.3 if some tests pass but no more than the previous best;
          * -0.5 if no test passes and the test run produced error output.
        A fix that does not parse skips the test run and gets the minimum
        reward. An unknown session also gets the minimum reward, with
        ``done=True``.

        Args:
            action: CodeDebugAction with the proposed fix.

        Returns:
            Observation with test results, reward, and feedback. The session
            is discarded once the episode is done (solved or out of attempts).
        """
        session_id = action.session_id
        if session_id not in self._sessions:
            return CodeDebugObservation(
                buggy_code="",
                current_code="",
                error_output="Invalid session_id. Call reset() first.",
                # Kept inside (0, 1) like every other reward this class emits.
                reward=self._MIN_REWARD,
                done=True,
                feedback="Error: Invalid session. Please call reset() first.",
            )

        state = self._sessions[session_id]
        bug = state._bug_data  # type: ignore[attr-defined]

        # Increment attempt
        state.attempt += 1
        proposed_fix = action.proposed_fix.strip()

        # 1. Syntax check
        is_valid_syntax, syntax_error = check_syntax(proposed_fix)
        if not is_valid_syntax:
            state.tests_passed_history.append(0)
            state.current_code = proposed_fix
            done = state.attempt >= state.max_attempts
            feedback = self._build_fix_feedback(
                bug, [], 0, len(bug["tests"]),
                syntax_error=syntax_error, attempt=state.attempt,
                max_attempts=state.max_attempts,
            )
            if done:
                self._sessions.pop(session_id, None)
            return CodeDebugObservation(
                buggy_code=bug["buggy_code"],
                current_code=proposed_fix,
                error_output=syntax_error,
                test_results=[],
                tests_passed=0,
                tests_total=len(bug["tests"]),
                reward=self._MIN_REWARD,  # Syntax error gives the minimum reward
                done=done,
                attempt=state.attempt,
                max_attempts=state.max_attempts,
                feedback=feedback,
            )

        # --- Reward computation ---
        total_reward = 0.0

        # 2. Repetition check
        fix_hash = hashlib.sha256(proposed_fix.encode()).hexdigest()
        is_repeated = fix_hash in state.fix_hashes
        state.fix_hashes.append(fix_hash)
        if is_repeated:
            total_reward -= 0.5

        # 3. Run tests
        test_results, passed, total, error_output = run_tests(
            proposed_fix, bug["tests"]
        )

        # 4. Correctness reward
        if passed == total:
            total_reward += 2.0
            state.solved = True
        elif state.tests_passed_history:
            prev_best = max(state.tests_passed_history)
            if passed > prev_best:
                total_reward += 0.5  # Progress
            elif passed > 0:
                total_reward -= 0.3  # Stagnation
        elif passed > 0:
            total_reward += 0.5  # First partial success

        # 5. Runtime error penalty (0 tests passed; syntax is valid here)
        if passed == 0 and error_output:
            total_reward -= 0.5

        state.tests_passed_history.append(passed)
        state.current_code = proposed_fix

        # Strictly bound the reward to (0, 1) as required by Phase 2 Deep Validation
        total_reward = self._clamp_reward(total_reward)

        done = state.solved or state.attempt >= state.max_attempts
        feedback = self._build_fix_feedback(
            bug, test_results, passed, total,
            attempt=state.attempt, max_attempts=state.max_attempts,
            is_repeated=is_repeated,
        )
        if done:
            # Clean up session
            self._sessions.pop(session_id, None)

        return CodeDebugObservation(
            buggy_code=bug["buggy_code"],
            current_code=proposed_fix,
            error_output=error_output,
            test_results=test_results,
            tests_passed=passed,
            tests_total=total,
            reward=total_reward,
            done=done,
            attempt=state.attempt,
            max_attempts=state.max_attempts,
            feedback=feedback,
        )

    def get_state(self, session_id: str) -> Optional[CodeDebugState]:
        """Get current state for a session.

        Args:
            session_id: WebSocket session ID.

        Returns:
            Current state or None if session not found (finished episodes
            are removed, so they are reported as not found too).
        """
        return self._sessions.get(session_id)

    # --- Feedback builders ---

    def _build_feedback(
        self, bug: dict, test_results: List[TestResult],
        passed: int, total: int, is_initial: bool = False
    ) -> str:
        """Build the markdown bug report shown at the start of an episode.

        Args:
            bug: Bug record; ``function_name`` and ``buggy_code`` are read.
            test_results: Results of running the tests on the buggy code.
            passed: Number of passing tests.
            total: Number of tests.
            is_initial: The report is only produced when this is True.

        Returns:
            The report text, or an empty string when ``is_initial`` is False.
        """
        lines = []
        if is_initial:
            lines.append(f"## Bug Report: `{bug['function_name']}`")
            lines.append(f"The function `{bug['function_name']}` has a bug.")
            lines.append(f"**Current test results:** {passed}/{total} tests passing")
            lines.append("")
            lines.append("### Buggy Code:")
            lines.append(f"```python\n{bug['buggy_code']}\n```")
            lines.append("")
            lines.append("### Failing Tests:")
            lines.extend(
                f"- ❌ {tr.test_name}: {tr.error_message}"
                for tr in test_results
                if not tr.passed
            )
            lines.append("")
            lines.append("Please provide the corrected function.")
        return "\n".join(lines)

    def _build_fix_feedback(
        self, bug: dict, test_results: List[TestResult],
        passed: int, total: int,
        syntax_error: str = "", attempt: int = 0,
        max_attempts: int = 6, is_repeated: bool = False
    ) -> str:
        """Build the markdown feedback for one fix attempt.

        Args:
            bug: Bug record (not read here; kept for a uniform signature).
            test_results: Per-test results for the proposed fix.
            passed: Number of passing tests.
            total: Number of tests.
            syntax_error: Non-empty when the fix did not parse; the feedback
                then only reports that error.
            attempt: 1-based number of the current attempt.
            max_attempts: Attempts allowed in the episode.
            is_repeated: Whether the same fix was submitted before.

        Returns:
            The feedback text.
        """
        lines = [f"## Attempt {attempt}/{max_attempts}"]

        if syntax_error:
            lines.append(f"❌ **Syntax Error:** {syntax_error}")
            lines.append("Please fix the syntax and try again.")
            return "\n".join(lines)

        if is_repeated:
            lines.append("⚠️ **Warning:** You submitted the same fix as before. Try a different approach.")

        lines.append(f"**Tests:** {passed}/{total} passing")
        if passed == total:
            lines.append("✅ **All tests pass! Bug fixed successfully!**")
        else:
            lines.append("")
            lines.append("### Test Results:")
            for tr in test_results:
                status = "✅" if tr.passed else "❌"
                line = f"- {status} {tr.test_name}"
                if not tr.passed and tr.error_message:
                    line += f": {tr.error_message}"
                lines.append(line)
            remaining = max_attempts - attempt
            if remaining > 0:
                lines.append(f"\n{remaining} attempts remaining.")
        return "\n".join(lines)