File size: 6,853 Bytes
03a907a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import os
import re
import uuid as _uuid_module

from rl_code_fix_env.dataset.loader import get_hardcoded_task
from rl_code_fix_env.dataset.swebench_adapter import get_swebench_task
from rl_code_fix_env.src.reward.reward import compute_reward
from rl_code_fix_env.src.trace.tracer import TraceCollector
from rl_code_fix_env.src.sandbox.patcher import apply_patch
from rl_code_fix_env.src.sandbox.execution import run_test_file


class CodeEnv:
    """Episodic RL environment over a single code-repair task.

    An agent interacts via step() using actions of type "apply_patch",
    "run_tests", or "get_logs". reset() loads a new task — from SWE-bench
    or a local hardcoded set, selected by the TASK_SOURCE env var — and
    returns the initial observation.
    """

    # Scores and rewards are kept inside the open interval (0, 1) because
    # the downstream validator rejects exact 0.0 and 1.0.
    _EPS: float = 1e-6

    # Pre-compiled once (hoisted out of step(), where it was recompiled on
    # every test run). Matches the marker line emitted by the test runner,
    # e.g. "[TEST_COUNTS] passed=3 total=5".
    _TEST_COUNTS_RE = re.compile(r'\[TEST_COUNTS\]\s+passed=(\d+)\s+total=(\d+)')

    def __init__(self, max_steps: int = 10):
        """Create an environment with a per-episode step budget of *max_steps*."""
        self._data: dict = {}
        self.steps: int = 0
        self.max_steps: int = max_steps
        self.episode_id: str = str(_uuid_module.uuid4())
        self.tracer: TraceCollector = TraceCollector()
        self._state: dict = {}
        # "swebench"/"swebench_lite"/"swebench-lite" -> SWE-bench tasks;
        # anything else -> local hardcoded tasks.
        self.task_source: str = (os.getenv("TASK_SOURCE", "swebench") or "swebench").strip().lower()

    def reset(self, difficulty: str = "easy") -> dict:
        """Load a task by difficulty and initialise episode state.

        Returns:
            The initial observation dict (see _get_obs).
        """
        self.steps = 0
        self.episode_id = str(_uuid_module.uuid4())
        self.tracer = TraceCollector()

        startup_log = None
        if self.task_source in {"swebench", "swebench_lite", "swebench-lite"}:
            try:
                self._data = get_swebench_task(difficulty)
            except Exception as exc:
                # Optional graceful degradation to a local task when the
                # SWE-bench dataset cannot be loaded (enabled by default).
                allow_fallback = (
                    os.getenv("SWEBENCH_FALLBACK_LOCAL", "1").strip().lower()
                    in {"1", "true", "yes"}
                )
                if not allow_fallback:
                    raise
                startup_log = f"SWE-bench unavailable, fell back to local task: {exc}"
                self._data = get_hardcoded_task(difficulty)
        else:
            self._data = get_hardcoded_task(difficulty)

        self._state = {
            "code":       self._data["code"],
            "test_path":  self._data["tests"],
            "workspace":  self._data["problem_dir"],
            "logs":       startup_log,
            "test_score": 0.0,
            "prev_test_score": 0.0,  # Track previous score for regression penalty
            "passed":     0,
            "total":      1,
        }
        return self._get_obs()

    def step(self, action: dict) -> tuple:
        """Execute one action.

        action = {
            "type": "apply_patch" | "run_tests" | "get_logs",
            "payload": optional str
        }

        Returns:
            (obs_dict, reward, done, info)
        """
        prev_state = self._state.copy()
        prev_test_score = self._state.get("test_score", 0.0)

        action_type = action.get("type", "")
        payload     = action.get("payload") or ""

        if action_type == "apply_patch":
            patched_code, _ = apply_patch(self._state["code"], payload)
            self._state["code"] = patched_code
            # apply_patch consumes the step budget so the env itself can set
            # `done` when MAX_STEPS is reached (run_tests deliberately does
            # not increment — see the next branch).
            self.steps += 1

        elif action_type == "run_tests":
            # No self.steps += 1 here: steps increment on every meaningful
            # action (apply_patch), not on test runs, so the step budget is
            # consumed correctly and `done` fires at the right time.
            passed, logs = run_test_file(
                code_file=self._state["code"],
                test_file=self._state["test_path"],
                workspace_dir=self._state["workspace"],
            )
            self._record_test_results(passed, logs)

        elif action_type == "get_logs":
            self._state["logs"] = self._state.get("logs") or "No logs yet. Run tests first."

        # Record trace step: state before the action, the action itself,
        # and the state after it was applied.
        self.tracer.logs(
            state=prev_state,
            action=action,
            output=self._state,
        )

        done = (
            self._state["passed"] >= self._state["total"]  # all tests pass
            or self.steps >= self.max_steps                 # step budget exhausted
        )
        if action_type == "apply_patch":
            # Flag empty patches (set after tracing, preserving trace content)
            # so the reward can penalise no-op actions.
            self._state["last_action_empty"] = not (payload and payload.strip())
        last_action_empty = self._state.get("last_action_empty", False)
        try:
            reward = compute_reward(
                test_score=self._state["test_score"],
                trace_obj=self.tracer.steps,
                code=self._state["code"],
                steps_taken=self.steps,
                max_steps=self.max_steps,
                prev_test_score=prev_test_score,  # Pass for regression penalty
                last_action_empty=last_action_empty,
            )
            if self._state["passed"] >= self._state["total"]:
                # Exact 1.0 is rejected by the validator — use the highest
                # allowed value instead.
                reward = 1.0 - self._EPS
        except Exception:
            # Best-effort fallback: the reward degenerates to the clamped
            # raw test score if compute_reward itself fails.
            reward = max(self._EPS, min(1.0 - self._EPS, float(self._state.get("test_score", self._EPS))))

        return self._get_obs(), float(reward), done, {}

    def _record_test_results(self, passed: bool, logs: str) -> None:
        """Update pass counts, partial score, and logs after a test run.

        Prefers exact counts parsed from the "[TEST_COUNTS] passed=X total=Y"
        marker in *logs*; falls back to binary scoring when the marker is
        absent. Scores are clamped to the open interval (0, 1) because the
        validator rejects exact 0.0 and 1.0.
        """
        match = self._TEST_COUNTS_RE.search(logs)
        if match:
            passed_count = int(match.group(1))
            total_count = int(match.group(2))
            self._state["passed"] = passed_count
            self._state["total"] = max(total_count, 1)
            # Partial credit: passed/total, clamped to (0, 1).
            raw_score = passed_count / max(total_count, 1)
            self._state["test_score"] = max(self._EPS, min(1.0 - self._EPS, raw_score))
        else:
            # Binary scoring fallback, epsilon-clamped.
            self._state["passed"] = 1 if passed else 0
            self._state["total"] = 1
            self._state["test_score"] = (1.0 - self._EPS) if passed else self._EPS

        self._state["logs"] = logs

    def _get_obs(self) -> dict:
        """Return a snapshot observation of the current episode state."""
        return {
            "code":        self._state.get("code", ""),
            "logs":        self._state.get("logs"),
            "test_score":  float(self._state.get("test_score", 0.0)),
            "total_tests": int(self._state.get("total", 1)),
            "steps":       self.steps,
        }