Spaces:
Running
Running
import os
import re
import uuid as _uuid_module

from rl_code_fix_env.dataset.loader import get_hardcoded_task
from rl_code_fix_env.dataset.swebench_adapter import get_swebench_task
from rl_code_fix_env.src.reward.reward import compute_reward
from rl_code_fix_env.src.sandbox.execution import run_test_file
from rl_code_fix_env.src.sandbox.patcher import apply_patch
from rl_code_fix_env.src.trace.tracer import TraceCollector
| class CodeEnv: | |
| def __init__(self, max_steps: int = 10): | |
| self._data: dict = {} | |
| self.steps: int = 0 | |
| self.max_steps: int = max_steps | |
| self.episode_id: str = str(_uuid_module.uuid4()) | |
| self.tracer: TraceCollector = TraceCollector() | |
| self._state: dict = {} | |
| self.task_source: str = (os.getenv("TASK_SOURCE", "swebench") or "swebench").strip().lower() | |
| def reset(self, difficulty: str = "easy") -> dict: | |
| """ | |
| Load a task by difficulty and initialise episode state. | |
| Returns observation dict. | |
| """ | |
| self.steps = 0 | |
| self.episode_id = str(_uuid_module.uuid4()) | |
| self.tracer = TraceCollector() | |
| startup_log = None | |
| if self.task_source in {"swebench", "swebench_lite", "swebench-lite"}: | |
| try: | |
| self._data = get_swebench_task(difficulty) | |
| except Exception as exc: | |
| allow_fallback = ( | |
| os.getenv("SWEBENCH_FALLBACK_LOCAL", "1").strip().lower() | |
| in {"1", "true", "yes"} | |
| ) | |
| if not allow_fallback: | |
| raise | |
| startup_log = f"SWE-bench unavailable, fell back to local task: {exc}" | |
| self._data = get_hardcoded_task(difficulty) | |
| else: | |
| self._data = get_hardcoded_task(difficulty) | |
| self._state = { | |
| "code": self._data["code"], | |
| "test_path": self._data["tests"], | |
| "workspace": self._data["problem_dir"], | |
| "logs": startup_log, | |
| "test_score": 0.0, | |
| "prev_test_score": 0.0, # Track previous score for regression penalty | |
| "passed": 0, | |
| "total": 1, | |
| } | |
| return self._get_obs() | |
| def step(self, action: dict): | |
| """ | |
| Execute one action. | |
| action = { | |
| "type": "apply_patch" | "run_tests" | "get_logs", | |
| "payload": optional str | |
| } | |
| Returns: | |
| (obs_dict, reward, done, info) | |
| """ | |
| prev_state = self._state.copy() | |
| prev_test_score = self._state.get("test_score", 0.0) | |
| action_type = action.get("type", "") | |
| payload = action.get("payload") or "" | |
| if action_type == "apply_patch": | |
| patched_code, _ = apply_patch(self._state["code"], payload) | |
| self._state["code"] = patched_code | |
| # FIX: count apply_patch as a step so the budget is tracked correctly. | |
| # Previously only run_tests incremented steps, meaning MAX_STEPS was | |
| # never reached inside the env and done was never set by the env itself. | |
| self.steps += 1 | |
| elif action_type == "run_tests": | |
| # FIX: removed the extra self.steps += 1 here. | |
| # Steps now increment on every meaningful action (apply_patch), | |
| # not only on test runs. This ensures the step budget is consumed | |
| # correctly and `done` fires at the right time. | |
| passed, logs = run_test_file( | |
| code_file=self._state["code"], | |
| test_file=self._state["test_path"], | |
| workspace_dir=self._state["workspace"], | |
| ) | |
| # Parse test counts from logs if available (format: [TEST_COUNTS] passed=X total=Y) | |
| import re | |
| test_counts_match = re.search(r'\[TEST_COUNTS\]\s+passed=(\d+)\s+total=(\d+)', logs) | |
| if test_counts_match: | |
| passed_count = int(test_counts_match.group(1)) | |
| total_count = int(test_counts_match.group(2)) | |
| self._state["passed"] = passed_count | |
| self._state["total"] = max(total_count, 1) | |
| # Calculate partial score: passed/total — clamped to (0, 1) open interval | |
| # Validator rejects exact 0.0 and 1.0 | |
| _EPS = 1e-6 | |
| raw_score = passed_count / max(total_count, 1) | |
| self._state["test_score"] = max(_EPS, min(1.0 - _EPS, raw_score)) | |
| else: | |
| # Fallback to binary scoring if counts not found | |
| # Use epsilon-clamped values — validator rejects exact 0.0 and 1.0 | |
| _EPS = 1e-6 | |
| self._state["passed"] = 1 if passed else 0 | |
| self._state["total"] = 1 | |
| self._state["test_score"] = (1.0 - _EPS) if passed else _EPS | |
| self._state["logs"] = logs | |
| elif action_type == "get_logs": | |
| self._state["logs"] = self._state.get("logs") or "No logs yet. Run tests first." | |
| # Record trace step | |
| self.tracer.logs( | |
| state=prev_state, | |
| action=action, | |
| output=self._state, | |
| ) | |
| done = ( | |
| self._state["passed"] >= self._state["total"] # all tests pass | |
| or self.steps >= self.max_steps # step budget exhausted | |
| ) | |
| if action_type == "apply_patch": | |
| self._state["last_action_empty"] = not (payload and payload.strip()) | |
| last_action_empty = self._state.get("last_action_empty", False) | |
| try: | |
| reward = compute_reward( | |
| test_score=self._state["test_score"], | |
| trace_obj=self.tracer.steps, | |
| code=self._state["code"], | |
| steps_taken=self.steps, | |
| max_steps=self.max_steps, | |
| prev_test_score=prev_test_score, # Pass for regression penalty | |
| last_action_empty=last_action_empty, | |
| ) | |
| if self._state["passed"] >= self._state["total"]: | |
| # 1.0 is rejected by validator — use highest allowed value | |
| reward = 1.0 - 1e-6 | |
| except Exception: | |
| _EPS = 1e-6 | |
| reward = max(_EPS, min(1.0 - _EPS, float(self._state.get("test_score", _EPS)))) | |
| return self._get_obs(), float(reward), done, {} | |
| def _get_obs(self) -> dict: | |
| return { | |
| "code": self._state.get("code", ""), | |
| "logs": self._state.get("logs"), | |
| "test_score": float(self._state.get("test_score", 0.0)), | |
| "total_tests": int(self._state.get("total", 1)), | |
| "steps": self.steps, | |
| } |