Spaces:
Sleeping
Sleeping
File size: 3,515 Bytes
cacd58c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 | # server/environment.py
from __future__ import annotations

import random
import uuid

from openenv.core.env_server import Environment

from ..models import Action, Observation, State
from .grader import grade
from .tasks import TASK_REGISTRY
class CodeDebugEnvironment(Environment):
    """
    Real-world environment: AI agent must fix buggy Python functions.
    Episodes are multi-turn: agent iterates until all tests pass or max_steps reached.
    """

    # Episode length cap; also scales the efficiency bonus in step().
    DEFAULT_MAX_STEPS = 10

    def __init__(self):
        super().__init__()
        self._state = State()
        # Task dict from TASK_REGISTRY; populated by reset(), required by step().
        self._current_task = None

    def reset(
        self,
        seed: int | None = None,
        episode_id: str | None = None,
        task_id: str | None = None,
        **kwargs,
    ) -> Observation:
        """
        Start a new episode.

        - If task_id is None, sample a random task from the registry.
          Sampling is driven by `seed`, so the same seed replays the same
          task (previously `seed` was accepted but ignored).
        - If `episode_id` is provided it is recorded in the state; otherwise
          a fresh UUID is generated (previously always a fresh UUID).
        - Always returns a clean Observation with the buggy code.
        """
        if task_id is None:
            # Dedicated Random instance: honors `seed` without touching
            # the global random module state.
            rng = random.Random(seed)
            task_id = rng.choice(list(TASK_REGISTRY.keys()))
        task = TASK_REGISTRY[task_id]
        self._current_task = task
        self._state = State(
            episode_id=episode_id or str(uuid.uuid4()),
            task_id=task_id,
            step_count=0,
            max_steps=self.DEFAULT_MAX_STEPS,
            current_score=0.0,
            best_score=0.0,
        )
        return Observation(
            task_id=task_id,
            buggy_code=task["buggy_code"],
            task_description=task["description"],
            passed=0,
            total=task["num_tests"],
            score=0.0,
            done=False,
        )

    def step(
        self,
        action: Action,
        timeout_s: float | None = None,
        **kwargs,
    ) -> Observation:
        """
        Execute the agent's patch.
        Returns observation with test results and composite reward.

        Raises:
            RuntimeError: if called before reset().
        """
        if self._current_task is None:
            raise RuntimeError("Call reset() before step()")
        self._state.step_count += 1
        task = self._current_task
        # Grade the submission
        grade_result = grade(
            submitted_code=action.patch,
            task_id=action.task_id,
            test_suite=task["test_suite"],
        )
        # Composite reward — each term normalized to 0.0–1.0, weights
        # applied explicitly so they match this formula:
        # 0.5 * correctness + 0.2 * format + 0.2 * cot_bonus + 0.1 * efficiency
        max_steps = self._state.max_steps
        r_correct = grade_result["score"]  # 0.0–1.0
        r_format = 1.0 if grade_result["valid_syntax"] else 0.0
        # Bonus for a non-trivial reasoning trace (>20 chars).
        r_cot = 1.0 if (action.think and len(action.think) > 20) else 0.0
        # Linearly decays from 1.0 to 0.0 over the episode; uses the
        # state's max_steps rather than a hard-coded 10.
        r_eff = max(0.0, (max_steps - self._state.step_count) / max_steps)
        reward = 0.5 * r_correct + 0.2 * r_format + 0.2 * r_cot + 0.1 * r_eff
        reward = max(0.0, min(1.0, reward))
        # Penalty for timeout/crash
        if grade_result.get("timed_out"):
            reward = max(0.0, reward - 0.3)
        done = (r_correct == 1.0) or (self._state.step_count >= max_steps)
        self._state.current_score = reward
        self._state.best_score = max(self._state.best_score, reward)
        return Observation(
            task_id=action.task_id,
            buggy_code=action.patch,
            task_description=task["description"],
            test_results=grade_result["test_results"],
            passed=grade_result["passed"],
            total=grade_result["total"],
            score=reward,
            done=done,
            error=grade_result.get("error"),
        )

    @property
    def state(self) -> State:
        return self._state
|