arnavzz Claude Sonnet 4.6 commited on
Commit
3faaaa0
·
1 Parent(s): b95e073

refactor: simplify and fix efficiency issues

Browse files

- executor: embed code via repr() to avoid double file I/O; extract _failure()
helper to unify error result structure; use rsplit for stdout parsing;
remove unnecessary WHAT comments
- environment: cap episodes at 500 with LRU eviction to prevent memory leak;
unify observation building into single _build_observation(); store
tests_passed/total_tests in episode during step() so state() reads
directly instead of recalculating; extract _get_episode() helper
- app: pass req.action.code directly instead of model_dump()
- inference: simplify strip_fences() using removeprefix/removesuffix

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

code_debug_env/server/app.py CHANGED
@@ -1,10 +1,4 @@
1
- """
2
- FastAPI server exposing the OpenEnv-compatible HTTP API.
3
- Port: 7860 (Hugging Face Spaces default)
4
- """
5
-
6
  from fastapi import FastAPI, HTTPException
7
- from fastapi.responses import JSONResponse
8
 
9
  from ..models import (
10
  DebugState,
@@ -24,10 +18,6 @@ app = FastAPI(
24
  env = CodeDebugEnvironment()
25
 
26
 
27
- # ------------------------------------------------------------------
28
- # Health & metadata
29
- # ------------------------------------------------------------------
30
-
31
  @app.get("/health")
32
  async def health():
33
  return {"status": "healthy", "tasks_loaded": len(env.tasks)}
@@ -38,17 +28,12 @@ async def list_tasks():
38
  return env.list_tasks()
39
 
40
 
41
- # ------------------------------------------------------------------
42
- # OpenEnv core endpoints
43
- # ------------------------------------------------------------------
44
-
45
  @app.post("/reset", response_model=ResetResponse)
46
  async def reset(req: ResetRequest = None):
47
  if req is None:
48
  req = ResetRequest()
49
  try:
50
- result = env.reset(task_id=req.task_id, seed=req.seed)
51
- return result
52
  except KeyError as e:
53
  raise HTTPException(status_code=404, detail=str(e))
54
 
@@ -56,8 +41,7 @@ async def reset(req: ResetRequest = None):
56
  @app.post("/step/{episode_id}", response_model=StepResponse)
57
  async def step(episode_id: str, req: StepRequest):
58
  try:
59
- result = env.step(episode_id, req.action.model_dump())
60
- return result
61
  except KeyError as e:
62
  raise HTTPException(status_code=404, detail=str(e))
63
  except ValueError as e:
@@ -70,12 +54,3 @@ async def state(episode_id: str):
70
  return env.state(episode_id)
71
  except KeyError as e:
72
  raise HTTPException(status_code=404, detail=str(e))
73
-
74
-
75
- # ------------------------------------------------------------------
76
- # Entry point for local dev
77
- # ------------------------------------------------------------------
78
-
79
- if __name__ == "__main__":
80
- import uvicorn
81
- uvicorn.run("code_debug_env.server.app:app", host="0.0.0.0", port=7860, reload=True)
 
 
 
 
 
 
1
  from fastapi import FastAPI, HTTPException
 
2
 
3
  from ..models import (
4
  DebugState,
 
18
  env = CodeDebugEnvironment()
19
 
20
 
 
 
 
 
21
  @app.get("/health")
22
  async def health():
23
  return {"status": "healthy", "tasks_loaded": len(env.tasks)}
 
28
  return env.list_tasks()
29
 
30
 
 
 
 
 
31
  @app.post("/reset", response_model=ResetResponse)
32
  async def reset(req: ResetRequest = None):
33
  if req is None:
34
  req = ResetRequest()
35
  try:
36
+ return env.reset(task_id=req.task_id, seed=req.seed)
 
37
  except KeyError as e:
38
  raise HTTPException(status_code=404, detail=str(e))
39
 
 
41
  @app.post("/step/{episode_id}", response_model=StepResponse)
42
  async def step(episode_id: str, req: StepRequest):
43
  try:
44
+ return env.step(episode_id, req.action.code)
 
45
  except KeyError as e:
46
  raise HTTPException(status_code=404, detail=str(e))
47
  except ValueError as e:
 
54
  return env.state(episode_id)
55
  except KeyError as e:
56
  raise HTTPException(status_code=404, detail=str(e))
 
 
 
 
 
 
 
 
 
code_debug_env/server/environment.py CHANGED
@@ -5,15 +5,20 @@ Core environment logic: task loading, reset, step, state.
5
  import json
6
  import random
7
  import uuid
 
8
  from pathlib import Path
9
 
10
  from .executor import run_code_safely
11
 
 
 
 
12
 
13
  class CodeDebugEnvironment:
14
  def __init__(self):
15
  self.tasks: dict[str, dict] = {}
16
- self.episodes: dict[str, dict] = {}
 
17
  self._load_tasks()
18
 
19
  def _load_tasks(self):
@@ -23,10 +28,6 @@ class CodeDebugEnvironment:
23
  task = json.load(f)
24
  self.tasks[task["task_id"]] = task
25
 
26
- # ------------------------------------------------------------------
27
- # Public API
28
- # ------------------------------------------------------------------
29
-
30
  def reset(self, task_id: str | None = None, seed: int | None = None) -> dict:
31
  if seed is not None:
32
  random.seed(seed)
@@ -40,38 +41,34 @@ class CodeDebugEnvironment:
40
  task = self.tasks[task_id]
41
  episode_id = str(uuid.uuid4())
42
 
 
 
 
43
  self.episodes[episode_id] = {
44
  "episode_id": episode_id,
45
  "task": task,
46
  "step_count": 0,
47
  "done": False,
48
  "rewards": [],
 
 
49
  "last_test_results": [],
50
  }
51
 
52
- observation = self._initial_observation(task)
53
- return {"episode_id": episode_id, "observation": observation}
54
 
55
- def step(self, episode_id: str, action: dict) -> dict:
56
- if episode_id not in self.episodes:
57
- raise KeyError(f"Unknown episode_id: {episode_id!r}")
58
-
59
- ep = self.episodes[episode_id]
60
  if ep["done"]:
61
  raise ValueError("Episode is already finished. Call reset() to start a new episode.")
62
 
63
  task = ep["task"]
64
- submitted_code = action.get("code", "")
65
  ep["step_count"] += 1
66
 
67
- test_results_raw, stdout, stderr = run_code_safely(
68
- submitted_code,
69
- task["test_code"],
70
- timeout=10,
71
- )
72
 
73
- tests_passed = sum(1 for t in test_results_raw if t.get("passed", False))
74
- total_tests = len(test_results_raw)
75
  reward = round(tests_passed / total_tests, 4) if total_tests > 0 else 0.0
76
 
77
  max_steps = task.get("max_steps", 5)
@@ -79,45 +76,31 @@ class CodeDebugEnvironment:
79
 
80
  ep["done"] = done
81
  ep["rewards"].append(reward)
82
- ep["last_test_results"] = test_results_raw
 
 
 
 
83
 
84
- observation = {
85
- "task_id": task["task_id"],
86
- "difficulty": task["difficulty"],
87
- "description": task["description"],
88
- "buggy_code": task["buggy_code"],
89
- "test_descriptions": task["test_descriptions"],
90
- "test_results": test_results_raw,
91
- "stdout": stdout,
92
- "stderr": stderr,
93
- "step_count": ep["step_count"],
94
- "max_steps": max_steps,
95
  "reward": reward,
96
  "done": done,
97
- "total_tests": total_tests,
98
- "tests_passed": tests_passed,
99
  }
100
 
101
- return {"observation": observation, "reward": reward, "done": done, "info": {}}
102
-
103
  def state(self, episode_id: str) -> dict:
104
- if episode_id not in self.episodes:
105
- raise KeyError(f"Unknown episode_id: {episode_id!r}")
106
-
107
- ep = self.episodes[episode_id]
108
- task = ep["task"]
109
- last_results = ep.get("last_test_results", [])
110
-
111
  return {
112
  "episode_id": episode_id,
113
- "task_id": task["task_id"],
114
- "difficulty": task["difficulty"],
115
  "step_count": ep["step_count"],
116
- "max_steps": task.get("max_steps", 5),
117
  "last_reward": ep["rewards"][-1] if ep["rewards"] else 0.0,
118
  "cumulative_reward": round(sum(ep["rewards"]), 4),
119
- "tests_passed": sum(1 for t in last_results if t.get("passed", False)),
120
- "total_tests": len(last_results),
121
  "done": ep["done"],
122
  }
123
 
@@ -133,24 +116,27 @@ class CodeDebugEnvironment:
133
  for t in self.tasks.values()
134
  ]
135
 
136
- # ------------------------------------------------------------------
137
- # Internal helpers
138
- # ------------------------------------------------------------------
 
139
 
140
- def _initial_observation(self, task: dict) -> dict:
 
 
141
  return {
142
  "task_id": task["task_id"],
143
  "difficulty": task["difficulty"],
144
  "description": task["description"],
145
  "buggy_code": task["buggy_code"],
146
  "test_descriptions": task["test_descriptions"],
147
- "test_results": [],
148
- "stdout": "",
149
- "stderr": "",
150
- "step_count": 0,
151
  "max_steps": task.get("max_steps", 5),
152
- "reward": 0.0,
153
- "done": False,
154
- "total_tests": len(task["test_descriptions"]),
155
- "tests_passed": 0,
156
  }
 
5
  import json
6
  import random
7
  import uuid
8
+ from collections import OrderedDict
9
  from pathlib import Path
10
 
11
  from .executor import run_code_safely
12
 
13
+ # Cap in-memory episodes to prevent unbounded growth.
14
+ _MAX_EPISODES = 500
15
+
16
 
17
  class CodeDebugEnvironment:
18
  def __init__(self):
19
  self.tasks: dict[str, dict] = {}
20
+ # OrderedDict used as a simple LRU: oldest episode evicted when cap is hit.
21
+ self.episodes: OrderedDict[str, dict] = OrderedDict()
22
  self._load_tasks()
23
 
24
  def _load_tasks(self):
 
28
  task = json.load(f)
29
  self.tasks[task["task_id"]] = task
30
 
 
 
 
 
31
  def reset(self, task_id: str | None = None, seed: int | None = None) -> dict:
32
  if seed is not None:
33
  random.seed(seed)
 
41
  task = self.tasks[task_id]
42
  episode_id = str(uuid.uuid4())
43
 
44
+ if len(self.episodes) >= _MAX_EPISODES:
45
+ self.episodes.popitem(last=False)
46
+
47
  self.episodes[episode_id] = {
48
  "episode_id": episode_id,
49
  "task": task,
50
  "step_count": 0,
51
  "done": False,
52
  "rewards": [],
53
+ "tests_passed": 0,
54
+ "total_tests": len(task["test_descriptions"]),
55
  "last_test_results": [],
56
  }
57
 
58
+ return {"episode_id": episode_id, "observation": self._build_observation(episode_id)}
 
59
 
60
+ def step(self, episode_id: str, code: str) -> dict:
61
+ ep = self._get_episode(episode_id)
 
 
 
62
  if ep["done"]:
63
  raise ValueError("Episode is already finished. Call reset() to start a new episode.")
64
 
65
  task = ep["task"]
 
66
  ep["step_count"] += 1
67
 
68
+ test_results, stdout, stderr = run_code_safely(code, task["test_code"], timeout=10)
 
 
 
 
69
 
70
+ tests_passed = sum(1 for t in test_results if t.get("passed", False))
71
+ total_tests = len(test_results)
72
  reward = round(tests_passed / total_tests, 4) if total_tests > 0 else 0.0
73
 
74
  max_steps = task.get("max_steps", 5)
 
76
 
77
  ep["done"] = done
78
  ep["rewards"].append(reward)
79
+ ep["tests_passed"] = tests_passed
80
+ ep["total_tests"] = total_tests
81
+ ep["last_test_results"] = test_results
82
+ ep["last_stdout"] = stdout
83
+ ep["last_stderr"] = stderr
84
 
85
+ return {
86
+ "observation": self._build_observation(episode_id),
 
 
 
 
 
 
 
 
 
87
  "reward": reward,
88
  "done": done,
89
+ "info": {},
 
90
  }
91
 
 
 
92
  def state(self, episode_id: str) -> dict:
93
+ ep = self._get_episode(episode_id)
 
 
 
 
 
 
94
  return {
95
  "episode_id": episode_id,
96
+ "task_id": ep["task"]["task_id"],
97
+ "difficulty": ep["task"]["difficulty"],
98
  "step_count": ep["step_count"],
99
+ "max_steps": ep["task"].get("max_steps", 5),
100
  "last_reward": ep["rewards"][-1] if ep["rewards"] else 0.0,
101
  "cumulative_reward": round(sum(ep["rewards"]), 4),
102
+ "tests_passed": ep["tests_passed"],
103
+ "total_tests": ep["total_tests"],
104
  "done": ep["done"],
105
  }
106
 
 
116
  for t in self.tasks.values()
117
  ]
118
 
119
+ def _get_episode(self, episode_id: str) -> dict:
120
+ if episode_id not in self.episodes:
121
+ raise KeyError(f"Unknown episode_id: {episode_id!r}")
122
+ return self.episodes[episode_id]
123
 
124
+ def _build_observation(self, episode_id: str) -> dict:
125
+ ep = self.episodes[episode_id]
126
+ task = ep["task"]
127
  return {
128
  "task_id": task["task_id"],
129
  "difficulty": task["difficulty"],
130
  "description": task["description"],
131
  "buggy_code": task["buggy_code"],
132
  "test_descriptions": task["test_descriptions"],
133
+ "test_results": ep["last_test_results"],
134
+ "stdout": ep.get("last_stdout", ""),
135
+ "stderr": ep.get("last_stderr", ""),
136
+ "step_count": ep["step_count"],
137
  "max_steps": task.get("max_steps", 5),
138
+ "reward": ep["rewards"][-1] if ep["rewards"] else 0.0,
139
+ "done": ep["done"],
140
+ "total_tests": ep["total_tests"],
141
+ "tests_passed": ep["tests_passed"],
142
  }
code_debug_env/server/executor.py CHANGED
@@ -1,9 +1,5 @@
1
  """
2
- Safe code execution engine.
3
-
4
- Runs submitted code in a subprocess with timeout.
5
- Writes code to a temp directory, generates a test harness,
6
- and parses JSON results from stdout.
7
  """
8
 
9
  import json
@@ -22,40 +18,26 @@ def run_code_safely(
22
  """
23
  Execute submitted code against test cases in an isolated subprocess.
24
 
25
- Args:
26
- submitted_code: The Python code the agent submitted as a fix.
27
- test_code: Python snippet that populates a `results` list with test dicts.
28
- timeout: Max seconds before killing the subprocess.
29
-
30
- Returns:
31
- (test_results, stdout_extra, stderr) where test_results is a list of
32
- {"test_name", "passed", "expected", "actual", "error"} dicts.
33
  """
34
  with tempfile.TemporaryDirectory() as tmpdir:
35
- solution_path = Path(tmpdir) / "solution.py"
36
  harness_path = Path(tmpdir) / "harness.py"
37
 
38
- # Write the submitted code as a module
39
- solution_path.write_text(submitted_code, encoding="utf-8")
40
-
41
- # Build the test harness
42
  harness = textwrap.dedent(f"""\
43
  import sys, json, traceback
44
- sys.path.insert(0, r"{tmpdir}")
45
 
46
  results = []
47
 
48
- # Execute the submitted solution in this namespace
49
  try:
50
- exec(open(r"{solution_path}", encoding="utf-8").read())
51
- except Exception as e:
52
- # If the solution itself fails to load, all tests fail
53
  print(json.dumps([{{"test_name": "load", "passed": False,
54
  "expected": "code loads", "actual": "",
55
  "error": traceback.format_exc()}}]))
56
  sys.exit(0)
57
 
58
- # Run the test code (populates `results`)
59
  {textwrap.indent(test_code, " ").strip()}
60
 
61
  print(json.dumps(results))
@@ -74,44 +56,22 @@ def run_code_safely(
74
  stdout = proc.stdout.strip()
75
  stderr = proc.stderr.strip()
76
 
77
- # Parse test results from last line of stdout (the JSON array)
78
  if stdout:
79
- # The JSON array should be the last line
80
- lines = stdout.split("\n")
81
- json_line = lines[-1]
82
- extra_output = "\n".join(lines[:-1]) if len(lines) > 1 else ""
83
  try:
84
- test_results = json.loads(json_line)
85
- return test_results, extra_output, stderr
86
  except json.JSONDecodeError:
87
- return [
88
- {
89
- "test_name": "parse_error",
90
- "passed": False,
91
- "expected": "valid JSON output",
92
- "actual": stdout[:200],
93
- "error": "Could not parse test results from subprocess output",
94
- }
95
- ], "", stderr
96
-
97
- # No stdout at all — likely a crash
98
- return [
99
- {
100
- "test_name": "execution_error",
101
- "passed": False,
102
- "expected": "code runs",
103
- "actual": "",
104
- "error": stderr[:500] if stderr else "No output produced",
105
- }
106
- ], "", stderr
107
 
108
  except subprocess.TimeoutExpired:
109
- return [
110
- {
111
- "test_name": "timeout",
112
- "passed": False,
113
- "expected": f"completes within {timeout}s",
114
- "actual": "timed out",
115
- "error": f"Code execution exceeded {timeout} second timeout",
116
- }
117
- ], "", "Execution timed out"
 
1
  """
2
+ Safe code execution engine using subprocess with timeout.
 
 
 
 
3
  """
4
 
5
  import json
 
18
  """
19
  Execute submitted code against test cases in an isolated subprocess.
20
 
21
+ Returns (test_results, extra_stdout, stderr).
22
+ test_results is a list of {test_name, passed, expected, actual, error} dicts.
 
 
 
 
 
 
23
  """
24
  with tempfile.TemporaryDirectory() as tmpdir:
 
25
  harness_path = Path(tmpdir) / "harness.py"
26
 
27
+ # Embed submitted code directly via repr() to avoid a second file read.
 
 
 
28
  harness = textwrap.dedent(f"""\
29
  import sys, json, traceback
 
30
 
31
  results = []
32
 
 
33
  try:
34
+ exec({repr(submitted_code)})
35
+ except Exception:
 
36
  print(json.dumps([{{"test_name": "load", "passed": False,
37
  "expected": "code loads", "actual": "",
38
  "error": traceback.format_exc()}}]))
39
  sys.exit(0)
40
 
 
41
  {textwrap.indent(test_code, " ").strip()}
42
 
43
  print(json.dumps(results))
 
56
  stdout = proc.stdout.strip()
57
  stderr = proc.stderr.strip()
58
 
 
59
  if stdout:
60
+ # JSON array is always the last line; anything before it is extra output.
61
+ extra_output, json_line = stdout.rsplit("\n", maxsplit=1) if "\n" in stdout else ("", stdout)
 
 
62
  try:
63
+ return json.loads(json_line), extra_output, stderr
 
64
  except json.JSONDecodeError:
65
+ return _failure("parse_error", "valid JSON output", stdout[:200],
66
+ "Could not parse test results from subprocess output"), "", stderr
67
+
68
+ return _failure("execution_error", "code runs", "",
69
+ stderr[:500] if stderr else "No output produced"), "", stderr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
  except subprocess.TimeoutExpired:
72
+ return _failure("timeout", f"completes within {timeout}s", "timed out",
73
+ f"Code execution exceeded {timeout} second timeout"), "", "Execution timed out"
74
+
75
+
76
+ def _failure(name: str, expected: str, actual: str, error: str) -> list[dict]:
77
+ return [{"test_name": name, "passed": False, "expected": expected, "actual": actual, "error": error}]
 
 
 
inference.py CHANGED
@@ -111,12 +111,8 @@ def build_feedback_prompt(obs: dict) -> str:
111
 
112
  def strip_fences(text: str) -> str:
113
  text = text.strip()
114
- if text.startswith("```python"):
115
- text = text[len("```python"):].strip()
116
- elif text.startswith("```"):
117
- text = text[3:].strip()
118
- if text.endswith("```"):
119
- text = text[:-3].strip()
120
  return text
121
 
122
 
 
111
 
112
  def strip_fences(text: str) -> str:
113
  text = text.strip()
114
+ text = text.removeprefix("```python").removeprefix("```").strip()
115
+ text = text.removesuffix("```").strip()
 
 
 
 
116
  return text
117
 
118