File size: 3,515 Bytes
cacd58c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# server/environment.py
from __future__ import annotations
import uuid
from openenv.core.env_server import Environment
from ..models import Action, Observation, State
from .grader import grade
from .tasks import TASK_REGISTRY


class CodeDebugEnvironment(Environment):
    """
    Real-world environment: AI agent must fix buggy Python functions.
    Episodes are multi-turn: agent iterates until all tests pass or max_steps reached.

    Typical usage is ``reset()`` to start an episode, then repeated ``step()``
    calls with the agent's patched code until the returned observation has
    ``done=True``.
    """

    # Step budget assigned to every episode created by reset().
    _DEFAULT_MAX_STEPS = 10

    # Weights of the composite reward computed in step(). Each component is
    # in [0, 1] before weighting.
    _W_CORRECT = 0.5
    _W_FORMAT = 0.2
    _W_COT = 0.2
    _W_EFFICIENCY = 0.1

    # Amount subtracted from the (already clamped) reward when the grader
    # reports a timeout.
    _TIMEOUT_PENALTY = 0.3

    # The chain-of-thought bonus requires strictly more than this many chars.
    _MIN_THINK_CHARS = 20

    def __init__(self):
        """Create an environment with no active episode; call reset() first."""
        super().__init__()
        self._state = State()
        # Task entry from TASK_REGISTRY for the active episode (None until
        # reset() has been called).
        self._current_task = None

    def reset(
        self,
        seed: int | None = None,
        episode_id: str | None = None,
        task_id: str | None = None,
        **kwargs,
    ) -> Observation:
        """
        Start a new episode.
        - If task_id is None, sample a random task from the registry.
        - Always returns a clean Observation with the buggy code.

        Args:
            seed: Optional seed for the random task choice. Only consulted
                when ``task_id`` is None; the same seed then selects the same
                task. When None, the global ``random`` state is used.
            episode_id: Optional identifier for the new episode. When None, a
                fresh UUID4 string is generated.
            task_id: Key into ``TASK_REGISTRY`` selecting the task to debug.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            Observation holding the task's buggy code and description, with
            zero passed tests, zero score and ``done=False``.

        Raises:
            KeyError: If ``task_id`` is not present in ``TASK_REGISTRY``.
        """
        if task_id is None:
            import random

            # A private generator keeps seeded resets reproducible without
            # disturbing the process-wide random state.
            rng = random.Random(seed) if seed is not None else random
            task_id = rng.choice(list(TASK_REGISTRY.keys()))

        task = TASK_REGISTRY[task_id]
        self._current_task = task

        # Honor a caller-supplied episode id; otherwise mint a new one.
        if episode_id is None:
            episode_id = str(uuid.uuid4())

        self._state = State(
            episode_id=episode_id,
            task_id=task_id,
            step_count=0,
            max_steps=self._DEFAULT_MAX_STEPS,
            current_score=0.0,
            best_score=0.0,
        )

        return Observation(
            task_id=task_id,
            buggy_code=task["buggy_code"],
            task_description=task["description"],
            passed=0,
            total=task["num_tests"],
            score=0.0,
            done=False,
        )

    def step(
        self,
        action: Action,
        timeout_s: float | None = None,
        **kwargs,
    ) -> Observation:
        """
        Execute the agent's patch.
        Returns observation with test results and composite reward.

        The reward is
        ``0.5 * correctness + 0.2 * format + 0.2 * cot_bonus + 0.1 * efficiency``
        clamped to [0, 1], minus 0.3 (floored at 0) if the grader reports a
        timeout. The episode is done when correctness reaches 1.0 or the step
        budget is exhausted.

        Args:
            action: The agent's submission. ``action.patch`` is the code to
                grade and ``action.think`` is the optional reasoning text used
                for the chain-of-thought bonus.
            timeout_s: Accepted for interface compatibility; not used by this
                implementation.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            Observation with the grader's test results, pass counts, the
            composite reward as ``score``, the ``done`` flag and any grader
            error.

        Raises:
            RuntimeError: If called before ``reset()``.
        """
        if self._current_task is None:
            raise RuntimeError("Call reset() before step()")

        state = self._state
        task = self._current_task
        state.step_count += 1

        # Grade and report under the episode's own task id: the test suite
        # below belongs to that task, so an id carried on the action must not
        # be able to relabel the submission.
        episode_task_id = state.task_id

        # Grade the submission
        grade_result = grade(
            submitted_code=action.patch,
            task_id=episode_task_id,
            test_suite=task["test_suite"],
        )

        # Raw reward components, each in [0, 1].
        r_correct = grade_result["score"]          # 0.0–1.0
        r_format = 1.0 if grade_result["valid_syntax"] else 0.0
        has_reasoning = bool(action.think) and len(action.think) > self._MIN_THINK_CHARS
        r_cot = 1.0 if has_reasoning else 0.0

        # Fraction of the step budget still unused; derived from the episode's
        # max_steps so it stays in sync with the budget set in reset().
        max_steps = state.max_steps
        if max_steps and max_steps > 0:
            r_eff = max(0.0, (max_steps - state.step_count) / max_steps)
        else:
            r_eff = 0.0

        reward = (
            self._W_CORRECT * r_correct
            + self._W_FORMAT * r_format
            + self._W_COT * r_cot
            + self._W_EFFICIENCY * r_eff
        )
        reward = max(0.0, min(1.0, reward))

        # Penalty for timeout/crash
        if grade_result.get("timed_out"):
            reward = max(0.0, reward - self._TIMEOUT_PENALTY)

        done = (r_correct == 1.0) or (state.step_count >= state.max_steps)

        state.current_score = reward
        state.best_score = max(state.best_score, reward)

        return Observation(
            task_id=episode_task_id,
            # The observation echoes the submitted patch in the buggy_code field.
            buggy_code=action.patch,
            task_description=task["description"],
            test_results=grade_result["test_results"],
            passed=grade_result["passed"],
            total=grade_result["total"],
            score=reward,
            done=done,
            error=grade_result.get("error"),
        )

    @property
    def state(self) -> State:
        """Return the current episode state (replaced on every reset())."""
        return self._state