import os
import re
import uuid as _uuid_module

from rl_code_fix_env.dataset.loader import get_hardcoded_task
from rl_code_fix_env.dataset.swebench_adapter import get_swebench_task
from rl_code_fix_env.src.reward.reward import compute_reward
from rl_code_fix_env.src.sandbox.execution import run_test_file
from rl_code_fix_env.src.sandbox.patcher import apply_patch
from rl_code_fix_env.src.trace.tracer import TraceCollector

# The downstream validator rejects scores/rewards of exactly 0.0 or 1.0, so
# every score produced here is clamped into the open interval (0, 1) with this
# single epsilon (previously re-declared as a literal in three places).
_EPS = 1e-6

# Test-count line emitted by the runner, format: "[TEST_COUNTS] passed=X total=Y".
# Compiled once at import time instead of per step() call.
_TEST_COUNTS_RE = re.compile(r'\[TEST_COUNTS\]\s+passed=(\d+)\s+total=(\d+)')


class CodeEnv:
    """Gym-style environment for iterative code fixing.

    An agent repeatedly applies patches to a buggy file and runs its tests.
    An episode ends when every test passes or the step budget is exhausted.
    Tasks come from SWE-bench (default) or a local hard-coded set, selected
    via the TASK_SOURCE environment variable.
    """

    def __init__(self, max_steps: int = 10):
        self._data: dict = {}          # raw task record from the loader
        self.steps: int = 0            # steps consumed this episode
        self.max_steps: int = max_steps
        self.episode_id: str = str(_uuid_module.uuid4())
        self.tracer: TraceCollector = TraceCollector()
        self._state: dict = {}         # mutable episode state (code, logs, scores)
        # Task source is env-configurable; the `or "swebench"` guards against
        # TASK_SOURCE being set to an empty string.
        self.task_source: str = (
            (os.getenv("TASK_SOURCE", "swebench") or "swebench").strip().lower()
        )

    def reset(self, difficulty: str = "easy") -> dict:
        """Load a task by difficulty and initialise episode state.

        Falls back to the local hard-coded task set when SWE-bench is
        unavailable, unless SWEBENCH_FALLBACK_LOCAL disables the fallback
        (in which case the loader's exception propagates).

        Returns the initial observation dict (see _get_obs).
        """
        self.steps = 0
        self.episode_id = str(_uuid_module.uuid4())
        self.tracer = TraceCollector()

        startup_log = None
        if self.task_source in {"swebench", "swebench_lite", "swebench-lite"}:
            try:
                self._data = get_swebench_task(difficulty)
            except Exception as exc:
                allow_fallback = (
                    os.getenv("SWEBENCH_FALLBACK_LOCAL", "1").strip().lower()
                    in {"1", "true", "yes"}
                )
                if not allow_fallback:
                    raise
                startup_log = f"SWE-bench unavailable, fell back to local task: {exc}"
                self._data = get_hardcoded_task(difficulty)
        else:
            self._data = get_hardcoded_task(difficulty)

        self._state = {
            "code": self._data["code"],
            "test_path": self._data["tests"],
            "workspace": self._data["problem_dir"],
            "logs": startup_log,
            "test_score": 0.0,
            "prev_test_score": 0.0,  # Track previous score for regression penalty
            "passed": 0,
            "total": 1,
        }
        return self._get_obs()

    def _apply_patch_action(self, payload: str) -> None:
        """Apply a patch payload to the working code and consume one step."""
        patched_code, _ = apply_patch(self._state["code"], payload)
        self._state["code"] = patched_code
        # FIX: count apply_patch as a step so the budget is tracked correctly.
        # Previously only run_tests incremented steps, meaning MAX_STEPS was
        # never reached inside the env and done was never set by the env itself.
        self.steps += 1

    def _run_tests_action(self) -> None:
        """Run the task's tests and update passed/total/test_score/logs.

        Intentionally does NOT increment self.steps: steps are consumed by
        meaningful actions (apply_patch), not by test runs, so the budget
        fires at the right time.
        """
        passed, logs = run_test_file(
            code_file=self._state["code"],
            test_file=self._state["test_path"],
            workspace_dir=self._state["workspace"],
        )
        # Parse test counts from logs if available
        # (format: [TEST_COUNTS] passed=X total=Y).
        match = _TEST_COUNTS_RE.search(logs)
        if match:
            passed_count = int(match.group(1))
            total_count = int(match.group(2))
            self._state["passed"] = passed_count
            self._state["total"] = max(total_count, 1)
            # Partial score passed/total, clamped to the open interval (0, 1)
            # because the validator rejects exact 0.0 and 1.0.
            raw_score = passed_count / max(total_count, 1)
            self._state["test_score"] = max(_EPS, min(1.0 - _EPS, raw_score))
        else:
            # Fallback to binary scoring when counts are absent, using the
            # same epsilon clamp.
            self._state["passed"] = 1 if passed else 0
            self._state["total"] = 1
            self._state["test_score"] = (1.0 - _EPS) if passed else _EPS
        self._state["logs"] = logs

    def step(self, action: dict):
        """Execute one action.

        action = {
            "type": "apply_patch" | "run_tests" | "get_logs",
            "payload": optional str
        }

        Returns:
            (obs_dict, reward, done, info)
        """
        prev_state = self._state.copy()
        prev_test_score = self._state.get("test_score", 0.0)
        action_type = action.get("type", "")
        payload = action.get("payload") or ""

        if action_type == "apply_patch":
            self._apply_patch_action(payload)
        elif action_type == "run_tests":
            self._run_tests_action()
        elif action_type == "get_logs":
            self._state["logs"] = (
                self._state.get("logs") or "No logs yet. Run tests first."
            )

        # Record trace step
        self.tracer.logs(
            state=prev_state,
            action=action,
            output=self._state,
        )

        done = (
            self._state["passed"] >= self._state["total"]  # all tests pass
            or self.steps >= self.max_steps  # step budget exhausted
        )

        if action_type == "apply_patch":
            self._state["last_action_empty"] = not (payload and payload.strip())
        last_action_empty = self._state.get("last_action_empty", False)

        try:
            reward = compute_reward(
                test_score=self._state["test_score"],
                trace_obj=self.tracer.steps,
                code=self._state["code"],
                steps_taken=self.steps,
                max_steps=self.max_steps,
                prev_test_score=prev_test_score,  # Pass for regression penalty
                last_action_empty=last_action_empty,
            )
            if self._state["passed"] >= self._state["total"]:
                # Exact 1.0 is rejected by the validator — use the highest
                # allowed value on full success.
                reward = 1.0 - _EPS
        except Exception:
            # Best-effort fallback: if reward computation fails, fall back to
            # the (already clamped) test score so the episode keeps a signal.
            reward = max(
                _EPS, min(1.0 - _EPS, float(self._state.get("test_score", _EPS)))
            )

        return self._get_obs(), float(reward), done, {}

    def _get_obs(self) -> dict:
        """Build the observation dict exposed to the agent."""
        return {
            "code": self._state.get("code", ""),
            "logs": self._state.get("logs"),
            "test_score": float(self._state.get("test_score", 0.0)),
            "total_tests": int(self._state.get("total", 1)),
            "steps": self.steps,
        }