# code-debug-env / server/environment.py
# (uploaded via huggingface_hub; revision cacd58c, verified)
# server/environment.py
from __future__ import annotations

import random
import uuid

from openenv.core.env_server import Environment

from ..models import Action, Observation, State
from .grader import grade
from .tasks import TASK_REGISTRY
class CodeDebugEnvironment(Environment):
    """
    Real-world environment: an AI agent must fix buggy Python functions.

    Episodes are multi-turn: the agent submits patches until all tests pass
    or ``max_steps`` is reached.
    """

    def __init__(self):
        super().__init__()
        self._state = State()      # episode bookkeeping (scores, step count)
        self._current_task = None  # task dict from TASK_REGISTRY; set by reset()

    def reset(
        self,
        seed: int | None = None,
        episode_id: str | None = None,
        task_id: str | None = None,
        **kwargs,
    ) -> Observation:
        """
        Start a new episode.

        Args:
            seed: Optional RNG seed. When given, random task sampling is
                deterministic. (The original accepted this but ignored it.)
            episode_id: Optional external episode id; a fresh UUID4 is
                generated when omitted. (The original accepted this but
                ignored it.)
            task_id: Explicit task to load; a random registry task is
                sampled when None.

        Returns:
            A clean Observation carrying the buggy code for the task.

        Raises:
            KeyError: If ``task_id`` is not present in TASK_REGISTRY.
        """
        if task_id is None:
            # A seeded, local RNG makes sampling reproducible without
            # perturbing the global `random` module state. Sorting the keys
            # makes the seed -> task mapping stable across registry
            # construction orders.
            task_id = random.Random(seed).choice(sorted(TASK_REGISTRY))
        if task_id not in TASK_REGISTRY:
            raise KeyError(
                f"Unknown task_id {task_id!r}; "
                f"known tasks: {sorted(TASK_REGISTRY)}"
            )
        task = TASK_REGISTRY[task_id]
        self._current_task = task
        self._state = State(
            episode_id=episode_id or str(uuid.uuid4()),
            task_id=task_id,
            step_count=0,
            max_steps=10,
            current_score=0.0,
            best_score=0.0,
        )
        return Observation(
            task_id=task_id,
            buggy_code=task["buggy_code"],
            task_description=task["description"],
            passed=0,
            total=task["num_tests"],
            score=0.0,
            done=False,
        )

    def step(
        self,
        action: Action,
        timeout_s: float | None = None,
        **kwargs,
    ) -> Observation:
        """
        Grade the agent's submitted patch and return the next observation.

        Composite reward, clamped to [0, 1]:
            0.5 * correctness + 0.2 * valid-syntax + CoT bonus (<= 0.2)
            + efficiency bonus (<= 0.1), then -0.3 if the grader timed out.

        Args:
            action: The agent's submission; ``action.patch`` is graded.
            timeout_s: Accepted for interface compatibility; grading
                timeouts are reported by the grader itself via
                ``timed_out``. (Not forwarded — TODO confirm whether
                ``grade()`` should receive it.)

        Returns:
            Observation with test results, composite reward, and done flag.

        Raises:
            RuntimeError: If called before ``reset()``.
        """
        if self._current_task is None:
            raise RuntimeError("Call reset() before step()")
        self._state.step_count += 1
        task = self._current_task
        # Grade against the episode's own task id: trusting action.task_id
        # would let a mismatched action be labeled with the wrong task while
        # still being run against this episode's test suite.
        grade_result = grade(
            submitted_code=action.patch,
            task_id=self._state.task_id,
            test_suite=task["test_suite"],
        )
        r_correct = grade_result["score"]  # fraction of tests passed, 0.0-1.0
        r_format = 1.0 if grade_result["valid_syntax"] else 0.0
        # Chain-of-thought bonus: reward non-trivial reasoning text.
        r_cot = 0.2 if (action.think and len(action.think) > 20) else 0.0
        # Efficiency bonus shrinks as steps are consumed. Tied to the
        # episode's max_steps (the original hard-coded 10, which would
        # silently drift if max_steps were ever changed).
        max_steps = self._state.max_steps
        r_eff = max(0.0, (max_steps - self._state.step_count) / max_steps) * 0.1
        reward = 0.5 * r_correct + 0.2 * r_format + r_cot + r_eff
        reward = max(0.0, min(1.0, reward))
        # Penalty for grader timeout/crash, applied after clamping so the
        # floor of 0.0 still holds.
        if grade_result.get("timed_out"):
            reward = max(0.0, reward - 0.3)
        # >= avoids relying on exact float equality for the success check.
        done = (r_correct >= 1.0) or (self._state.step_count >= max_steps)
        self._state.current_score = reward
        self._state.best_score = max(self._state.best_score, reward)
        return Observation(
            task_id=self._state.task_id,
            buggy_code=action.patch,  # echo the submission so the agent can iterate on it
            task_description=task["description"],
            test_results=grade_result["test_results"],
            passed=grade_result["passed"],
            total=grade_result["total"],
            score=reward,
            done=done,
            error=grade_result.get("error"),
        )

    @property
    def state(self) -> State:
        """Current episode State (mutated in place by step())."""
        return self._state