# rl_code_fix_env/src/environment/environment.py
import os
import re
import uuid as _uuid_module

from rl_code_fix_env.dataset.loader import get_hardcoded_task
from rl_code_fix_env.dataset.swebench_adapter import get_swebench_task
from rl_code_fix_env.src.reward.reward import compute_reward
from rl_code_fix_env.src.sandbox.execution import run_test_file
from rl_code_fix_env.src.sandbox.patcher import apply_patch
from rl_code_fix_env.src.trace.tracer import TraceCollector
class CodeEnv:
    """RL environment for iterative code fixing (patch -> test -> reward)."""

    def __init__(self, max_steps: int = 10):
        """Initialise a fresh environment with a ``max_steps`` action budget."""
        # Task payload and mutable per-episode state.
        self._data: dict = {}
        self._state: dict = {}
        # Step accounting against the episode budget.
        self.steps: int = 0
        self.max_steps: int = max_steps
        # Unique id for this episode; regenerated on every reset().
        self.episode_id: str = str(_uuid_module.uuid4())
        self.tracer: TraceCollector = TraceCollector()
        # Task provenance from TASK_SOURCE; the `or` guard also covers an
        # empty-string env value, defaulting to "swebench" either way.
        source = os.getenv("TASK_SOURCE", "swebench") or "swebench"
        self.task_source: str = source.strip().lower()
def reset(self, difficulty: str = "easy") -> dict:
"""
Load a task by difficulty and initialise episode state.
Returns observation dict.
"""
self.steps = 0
self.episode_id = str(_uuid_module.uuid4())
self.tracer = TraceCollector()
startup_log = None
if self.task_source in {"swebench", "swebench_lite", "swebench-lite"}:
try:
self._data = get_swebench_task(difficulty)
except Exception as exc:
allow_fallback = (
os.getenv("SWEBENCH_FALLBACK_LOCAL", "1").strip().lower()
in {"1", "true", "yes"}
)
if not allow_fallback:
raise
startup_log = f"SWE-bench unavailable, fell back to local task: {exc}"
self._data = get_hardcoded_task(difficulty)
else:
self._data = get_hardcoded_task(difficulty)
self._state = {
"code": self._data["code"],
"test_path": self._data["tests"],
"workspace": self._data["problem_dir"],
"logs": startup_log,
"test_score": 0.0,
"prev_test_score": 0.0, # Track previous score for regression penalty
"passed": 0,
"total": 1,
}
return self._get_obs()
def step(self, action: dict):
"""
Execute one action.
action = {
"type": "apply_patch" | "run_tests" | "get_logs",
"payload": optional str
}
Returns:
(obs_dict, reward, done, info)
"""
prev_state = self._state.copy()
prev_test_score = self._state.get("test_score", 0.0)
action_type = action.get("type", "")
payload = action.get("payload") or ""
if action_type == "apply_patch":
patched_code, _ = apply_patch(self._state["code"], payload)
self._state["code"] = patched_code
# FIX: count apply_patch as a step so the budget is tracked correctly.
# Previously only run_tests incremented steps, meaning MAX_STEPS was
# never reached inside the env and done was never set by the env itself.
self.steps += 1
elif action_type == "run_tests":
# FIX: removed the extra self.steps += 1 here.
# Steps now increment on every meaningful action (apply_patch),
# not only on test runs. This ensures the step budget is consumed
# correctly and `done` fires at the right time.
passed, logs = run_test_file(
code_file=self._state["code"],
test_file=self._state["test_path"],
workspace_dir=self._state["workspace"],
)
# Parse test counts from logs if available (format: [TEST_COUNTS] passed=X total=Y)
import re
test_counts_match = re.search(r'\[TEST_COUNTS\]\s+passed=(\d+)\s+total=(\d+)', logs)
if test_counts_match:
passed_count = int(test_counts_match.group(1))
total_count = int(test_counts_match.group(2))
self._state["passed"] = passed_count
self._state["total"] = max(total_count, 1)
# Calculate partial score: passed/total — clamped to (0, 1) open interval
# Validator rejects exact 0.0 and 1.0
_EPS = 1e-6
raw_score = passed_count / max(total_count, 1)
self._state["test_score"] = max(_EPS, min(1.0 - _EPS, raw_score))
else:
# Fallback to binary scoring if counts not found
# Use epsilon-clamped values — validator rejects exact 0.0 and 1.0
_EPS = 1e-6
self._state["passed"] = 1 if passed else 0
self._state["total"] = 1
self._state["test_score"] = (1.0 - _EPS) if passed else _EPS
self._state["logs"] = logs
elif action_type == "get_logs":
self._state["logs"] = self._state.get("logs") or "No logs yet. Run tests first."
# Record trace step
self.tracer.logs(
state=prev_state,
action=action,
output=self._state,
)
done = (
self._state["passed"] >= self._state["total"] # all tests pass
or self.steps >= self.max_steps # step budget exhausted
)
if action_type == "apply_patch":
self._state["last_action_empty"] = not (payload and payload.strip())
last_action_empty = self._state.get("last_action_empty", False)
try:
reward = compute_reward(
test_score=self._state["test_score"],
trace_obj=self.tracer.steps,
code=self._state["code"],
steps_taken=self.steps,
max_steps=self.max_steps,
prev_test_score=prev_test_score, # Pass for regression penalty
last_action_empty=last_action_empty,
)
if self._state["passed"] >= self._state["total"]:
# 1.0 is rejected by validator — use highest allowed value
reward = 1.0 - 1e-6
except Exception:
_EPS = 1e-6
reward = max(_EPS, min(1.0 - _EPS, float(self._state.get("test_score", _EPS))))
return self._get_obs(), float(reward), done, {}
def _get_obs(self) -> dict:
return {
"code": self._state.get("code", ""),
"logs": self._state.get("logs"),
"test_score": float(self._state.get("test_score", 0.0)),
"total_tests": int(self._state.get("total", 1)),
"steps": self.steps,
}