import uuid as _uuid_module
import os
from rl_code_fix_env.dataset.loader import get_hardcoded_task
from rl_code_fix_env.dataset.swebench_adapter import get_swebench_task
from rl_code_fix_env.src.reward.reward import compute_reward
from rl_code_fix_env.src.trace.tracer import TraceCollector
from rl_code_fix_env.src.sandbox.patcher import apply_patch
from rl_code_fix_env.src.sandbox.execution import run_test_file
class CodeEnv:
    """Episodic environment for iterative code-fixing tasks.

    An episode starts with `reset()`, which loads a task (SWE-bench or a
    bundled local task) and exposes its buggy code.  The agent then calls
    `step()` with one of three actions: ``apply_patch`` (edit the code),
    ``run_tests`` (execute the task's test file), or ``get_logs`` (re-read
    the last test output).  Rewards come from ``compute_reward`` and are
    clamped into the open interval (_EPS, 1 - _EPS) because the downstream
    validator rejects exact 0.0 and 1.0.
    """

    # FIX: single source of truth for the epsilon clamp.  The original code
    # re-declared `_EPS = 1e-6` in four separate places; the validator's
    # open-interval requirement is now documented and defined once.
    _EPS = 1e-6

    def __init__(self, max_steps: int = 10):
        self._data: dict = {}          # raw task record from the loader
        self.steps: int = 0            # budget consumed this episode
        self.max_steps: int = max_steps
        self.episode_id: str = str(_uuid_module.uuid4())
        self.tracer: TraceCollector = TraceCollector()
        self._state: dict = {}         # mutable episode state; empty until reset()
        # "swebench" (the default) pulls tasks from SWE-bench; any other
        # value selects the bundled hardcoded tasks.
        self.task_source: str = (os.getenv("TASK_SOURCE", "swebench") or "swebench").strip().lower()

    def reset(self, difficulty: str = "easy") -> dict:
        """
        Load a task by difficulty and initialise episode state.
        Returns observation dict.
        """
        self.steps = 0
        self.episode_id = str(_uuid_module.uuid4())
        self.tracer = TraceCollector()
        startup_log = None
        if self.task_source in {"swebench", "swebench_lite", "swebench-lite"}:
            try:
                self._data = get_swebench_task(difficulty)
            except Exception as exc:
                # Optionally fall back to the local task set when SWE-bench
                # is unreachable; controlled by SWEBENCH_FALLBACK_LOCAL
                # (enabled by default).
                allow_fallback = (
                    os.getenv("SWEBENCH_FALLBACK_LOCAL", "1").strip().lower()
                    in {"1", "true", "yes"}
                )
                if not allow_fallback:
                    raise
                startup_log = f"SWE-bench unavailable, fell back to local task: {exc}"
                self._data = get_hardcoded_task(difficulty)
        else:
            self._data = get_hardcoded_task(difficulty)
        self._state = {
            "code": self._data["code"],
            "test_path": self._data["tests"],
            "workspace": self._data["problem_dir"],
            "logs": startup_log,
            "test_score": 0.0,
            "prev_test_score": 0.0,  # Track previous score for regression penalty
            # passed=0 / total=1 guarantees the episode cannot start "done".
            "passed": 0,
            "total": 1,
        }
        return self._get_obs()

    def step(self, action: dict):
        """
        Execute one action.
        action = {
            "type": "apply_patch" | "run_tests" | "get_logs",
            "payload": optional str
        }
        Returns:
            (obs_dict, reward, done, info)
        Raises:
            RuntimeError: if called before reset().
        """
        # FIX: fail fast with a clear message.  Previously, calling step()
        # before reset() surfaced as an opaque KeyError partway through the
        # method (at the `done` computation) with episode state half-mutated.
        if not self._state:
            raise RuntimeError("step() called before reset(); call reset() first.")
        prev_state = self._state.copy()
        prev_test_score = self._state.get("test_score", 0.0)
        action_type = action.get("type", "")
        payload = action.get("payload") or ""
        if action_type == "apply_patch":
            patched_code, _ = apply_patch(self._state["code"], payload)
            self._state["code"] = patched_code
            # Only apply_patch consumes the step budget; run_tests and
            # get_logs are free.  (See the `done` computation below.)
            self.steps += 1
        elif action_type == "run_tests":
            passed, logs = run_test_file(
                code_file=self._state["code"],
                test_file=self._state["test_path"],
                workspace_dir=self._state["workspace"],
            )
            # Parse test counts from logs if available (format: [TEST_COUNTS] passed=X total=Y)
            import re
            test_counts_match = re.search(r'\[TEST_COUNTS\]\s+passed=(\d+)\s+total=(\d+)', logs)
            if test_counts_match:
                passed_count = int(test_counts_match.group(1))
                total_count = int(test_counts_match.group(2))
                self._state["passed"] = passed_count
                self._state["total"] = max(total_count, 1)
                # Partial score passed/total, clamped to the open interval
                # because the validator rejects exact 0.0 and 1.0.
                raw_score = passed_count / max(total_count, 1)
                self._state["test_score"] = max(self._EPS, min(1.0 - self._EPS, raw_score))
            else:
                # Fallback to binary scoring when the marker is absent,
                # again epsilon-clamped for the validator.
                self._state["passed"] = 1 if passed else 0
                self._state["total"] = 1
                self._state["test_score"] = (1.0 - self._EPS) if passed else self._EPS
            self._state["logs"] = logs
        elif action_type == "get_logs":
            self._state["logs"] = self._state.get("logs") or "No logs yet. Run tests first."
        # Unknown action types fall through as no-ops but are still traced.
        # Record trace step
        self.tracer.logs(
            state=prev_state,
            action=action,
            output=self._state,
        )
        done = (
            self._state["passed"] >= self._state["total"]  # all tests pass
            or self.steps >= self.max_steps  # step budget exhausted
        )
        # NOTE(review): kept after tracer.logs() as in the original, so the
        # traced output for this step matches what was previously recorded.
        if action_type == "apply_patch":
            self._state["last_action_empty"] = not (payload and payload.strip())
        last_action_empty = self._state.get("last_action_empty", False)
        try:
            reward = compute_reward(
                test_score=self._state["test_score"],
                trace_obj=self.tracer.steps,
                code=self._state["code"],
                steps_taken=self.steps,
                max_steps=self.max_steps,
                prev_test_score=prev_test_score,  # Pass for regression penalty
                last_action_empty=last_action_empty,
            )
            if self._state["passed"] >= self._state["total"]:
                # Full success: 1.0 is rejected by the validator, so use the
                # highest allowed value.
                reward = 1.0 - self._EPS
        except Exception:
            # Deliberate best-effort fallback: if the reward model itself
            # fails, degrade to the clamped raw test score rather than
            # aborting the episode.
            reward = max(self._EPS, min(1.0 - self._EPS, float(self._state.get("test_score", self._EPS))))
        return self._get_obs(), float(reward), done, {}

    def _get_obs(self) -> dict:
        """Build the observation dict exposed to the agent."""
        return {
            "code": self._state.get("code", ""),
            "logs": self._state.get("logs"),
            "test_score": float(self._state.get("test_score", 0.0)),
            "total_tests": int(self._state.get("total", 1)),
            "steps": self.steps,
        }