# server/environment.py
from __future__ import annotations

import random
import uuid

from openenv.core.env_server import Environment

from ..models import Action, Observation, State
from .grader import grade
from .tasks import TASK_REGISTRY


class CodeDebugEnvironment(Environment):
    """
    Real-world environment: AI agent must fix buggy Python functions.

    Episodes are multi-turn: agent iterates until all tests pass or
    max_steps reached.
    """

    # Step budget given to every new episode.
    MAX_STEPS = 10

    # Weights applied to the correctness score and the syntax-validity flag.
    CORRECTNESS_WEIGHT = 0.5
    FORMAT_WEIGHT = 0.2
    # Flat bonus granted when the action carries a long-enough `think` text.
    COT_BONUS = 0.2
    # Upper bound of the efficiency term; it shrinks as steps are consumed.
    EFFICIENCY_WEIGHT = 0.1
    # `think` must be strictly longer than this many characters to earn the bonus.
    MIN_THINK_CHARS = 20
    # Amount subtracted from the reward when the grader reports a timeout.
    TIMEOUT_PENALTY = 0.3

    def __init__(self):
        super().__init__()
        self._state = State()
        # Task dict of the running episode; None until reset() is called.
        self._current_task = None

    def reset(
        self,
        seed: int | None = None,
        episode_id: str | None = None,
        task_id: str | None = None,
        **kwargs,
    ) -> Observation:
        """
        Start a new episode.

        - If task_id is None, sample a random task from the registry.
        - Always returns a clean Observation with the buggy code.

        Args:
            seed: When given, task sampling uses a private RNG seeded with
                this value, so the same seed selects the same task. When
                None, the module-level ``random`` generator is used.
            episode_id: Identifier to record for the episode. A fresh UUID4
                string is generated when None.
            task_id: Key into TASK_REGISTRY. Sampled at random when None.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            Observation for the untouched task: zero tests passed, score 0.0,
            ``done`` False.

        Raises:
            KeyError: If ``task_id`` is not a key of TASK_REGISTRY (assuming
                the registry behaves like a dict).
        """
        if task_id is None:
            # A dedicated Random instance keeps seeded sampling reproducible
            # without disturbing the global random state.
            rng = random.Random(seed) if seed is not None else random
            task_id = rng.choice(list(TASK_REGISTRY.keys()))

        task = TASK_REGISTRY[task_id]
        self._current_task = task

        if episode_id is None:
            episode_id = str(uuid.uuid4())

        self._state = State(
            episode_id=episode_id,
            task_id=task_id,
            step_count=0,
            max_steps=self.MAX_STEPS,
            current_score=0.0,
            best_score=0.0,
        )

        return Observation(
            task_id=task_id,
            buggy_code=task["buggy_code"],
            task_description=task["description"],
            passed=0,
            total=task["num_tests"],
            score=0.0,
            done=False,
        )

    def step(
        self,
        action: Action,
        timeout_s: float | None = None,
        **kwargs,
    ) -> Observation:
        """
        Execute the agent's patch.

        Returns observation with test results and composite reward.

        Args:
            action: Agent submission. ``action.patch`` is graded and
                ``action.think`` is inspected for the reasoning bonus.
            timeout_s: Accepted for interface compatibility; not currently
                forwarded to the grader.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            Observation carrying the grader's test results, the composite
            reward as ``score``, and ``done`` set once the correctness score
            reaches 1.0 or the step budget is exhausted.

        Raises:
            RuntimeError: If called before reset().
        """
        if self._current_task is None:
            raise RuntimeError("Call reset() before step()")

        self._state.step_count += 1
        task = self._current_task
        # Grade against the episode's own task id: the test suite comes from
        # the current task, so the id handed to the grader must match it even
        # if the agent sent a different one.
        task_id = self._state.task_id

        # Grade the submission
        grade_result = grade(
            submitted_code=action.patch,
            task_id=task_id,
            test_suite=task["test_suite"],
        )

        # Composite reward:
        #   0.5 * correctness + 0.2 * format + cot_bonus + efficiency
        # where cot_bonus is 0.2 or 0.0 and efficiency decays from just under
        # 0.1 towards 0.0 as the step budget is used up.
        r_correct = grade_result["score"]  # expected to lie in 0.0-1.0
        r_format = 1.0 if grade_result["valid_syntax"] else 0.0
        has_reasoning = bool(action.think) and len(action.think) > self.MIN_THINK_CHARS
        r_cot = self.COT_BONUS if has_reasoning else 0.0

        # Scale by the episode's actual budget rather than a literal 10 so the
        # term stays correct if max_steps changes.
        max_steps = self._state.max_steps
        if max_steps and max_steps > 0:
            remaining = (max_steps - self._state.step_count) / max_steps
        else:
            remaining = 0.0
        r_eff = max(0.0, remaining) * self.EFFICIENCY_WEIGHT

        reward = (
            self.CORRECTNESS_WEIGHT * r_correct
            + self.FORMAT_WEIGHT * r_format
            + r_cot
            + r_eff
        )
        reward = max(0.0, min(1.0, reward))

        # Penalty for timeout/crash (applied after clamping, floored at 0.0)
        if grade_result.get("timed_out"):
            reward = max(0.0, reward - self.TIMEOUT_PENALTY)

        # >= avoids relying on exact float equality for a perfect score.
        solved = r_correct >= 1.0
        out_of_steps = self._state.step_count >= self._state.max_steps
        done = solved or out_of_steps

        self._state.current_score = reward
        self._state.best_score = max(self._state.best_score, reward)

        return Observation(
            task_id=task_id,
            # The latest patch is echoed back in the buggy_code field.
            buggy_code=action.patch,
            task_description=task["description"],
            test_results=grade_result["test_results"],
            passed=grade_result["passed"],
            total=grade_result["total"],
            score=reward,
            done=done,
            error=grade_result.get("error"),
        )

    @property
    def state(self) -> State:
        """Return the State object tracking the current episode."""
        return self._state