The-Fool-09 committed · Commit 8412998 · verified · 1 Parent(s): 22b11ca

Upload folder using huggingface_hub

Files changed (10)
  1. Blog.md +98 -98
  2. Dockerfile +1 -0
  3. client.py +23 -23
  4. eval/api_baseline.py +346 -346
  5. inference.py +332 -332
  6. models.py +9 -9
  7. server/debugZero_environment.py +224 -224
  8. server/graders.py +19 -19
  9. server/tasks.py +92 -92
  10. validate-submission.sh +184 -184
Blog.md CHANGED
@@ -1,98 +1,98 @@
1
- # DebugZero: Teaching a Coding Agent to Create and Fix Bugs
2
-
3
- Most code benchmarks ask a model to write a fresh solution from scratch. That is useful, but it skips a big part of real programming work: debugging code that is almost correct.
4
-
5
- That is the problem we built **DebugZero** to explore.
6
-
7
- DebugZero is an OpenEnv environment where a coding agent learns through a two-role game:
8
-
9
- - a **Proposer** takes a correct function and introduces a small but meaningful bug
10
- - a **Solver** takes that buggy function and tries to repair it
11
-
12
- The environment runs the submitted code in a sandbox, executes tests, and returns structured observations and rewards. In other words, the model does not just generate code and hope for the best. It acts inside an environment that can tell it whether a bug is real, whether a fix works, and whether the behavior is improving over time.
13
-
14
- ## Why we built it
15
-
16
- We wanted an environment that treats debugging as a first-class skill.
17
-
18
- In practice, strong programmers do more than write correct code. They also:
19
-
20
- - recognize how correct-looking code can fail
21
- - make small, targeted edits instead of rewriting everything
22
- - use test failures as evidence
23
- - recover from mistakes efficiently
24
-
25
- Static benchmarks usually measure the end result. DebugZero is meant to train the process.
26
-
27
- ## How an episode works
28
-
29
- Each episode starts from a clean seed task: a short Python function plus a hidden test harness.
30
-
31
- On the first turn, the proposer submits a modified version of the function. The goal is not to destroy the program randomly. The goal is to create a bug that is realistic, small, and detectable by tests.
32
-
33
- The environment then:
34
-
35
- 1. parses the submitted code
36
- 2. executes it in a sandboxed subprocess
37
- 3. runs the task tests
38
- 4. returns the current code, execution result, test status, reward, and next role
39
-
40
- If the proposer successfully creates a valid bug, the solver gets the next turn. The solver then submits a repaired function, and the environment checks whether the original behavior has been restored.
41
-
42
- This makes the whole loop executable and grounded. The agent is not rewarded for sounding plausible. It is rewarded for actually changing program behavior in the intended way.
43
-
44
- ## What makes the reward signal useful
45
-
46
- DebugZero uses role-aware rewards instead of a single generic success metric.
47
-
48
- For the proposer, reward is higher when the bug is:
49
-
50
- - syntactically valid
51
- - actually test-breaking
52
- - close to the original implementation rather than random corruption
53
-
54
- For the solver, reward is higher when the fix cleanly restores the expected behavior.
55
-
56
- That design matters because it pushes both roles toward realistic debugging behavior. The proposer learns to create useful failures. The solver learns to make precise repairs.
57
-
58
- ## What we trained
59
-
60
- We trained a policy for this environment using **GRPO** and role-conditioned prompting. One important design choice was to train against the **deployed environment itself**, not against notebook-local copies of the environment logic.
61
-
62
- That means the training loop interacts with the same OpenEnv interface that serves the environment in deployment:
63
-
64
- - reset the environment
65
- - observe the current task state
66
- - submit a proposer or solver action
67
- - receive reward and updated observation
68
-
69
- This kept training aligned with the real environment instead of drifting into a separate offline approximation.
70
-
71
- ## Why the two-role setup is interesting
72
-
73
- The most fun part of DebugZero is that it creates its own pressure to improve.
74
-
75
- If the solver becomes stronger, the proposer has to invent better bugs. If the proposer becomes better at making subtle failures, the solver has to become more precise at repair. That gives us a natural self-play curriculum for debugging.
76
-
77
- Instead of hand-authoring every training example, we get an environment where challenge and skill can rise together.
78
-
79
- ## What DebugZero is really trying to test
80
-
81
- At a deeper level, this project is about whether coding agents can become better debuggers through interaction rather than static supervision alone.
82
-
83
- We care about questions like:
84
-
85
- - Can an agent learn to create realistic failure modes?
86
- - Can it repair bugs without over-editing the program?
87
- - Can self-play produce a useful curriculum for code reasoning?
88
- - Can reward grounded in execution and tests teach something that static datasets miss?
89
-
90
- DebugZero is our attempt at turning those questions into something concrete and measurable.
91
-
92
- ## Links
93
-
94
- - Hugging Face Space: https://the-fool-09-debugzero.hf.space
95
- - Hugging Face project page: https://huggingface.co/spaces/The-Fool-09/debugZero
96
- - Training notebook: `notebooks/train_colab_updated_1.ipynb`
97
-
98
- In short, DebugZero is not just a benchmark where a model writes code. It is an environment where the model learns from failure, creates new failure cases, and improves through the loop of breaking and repairing programs. That is the behavior we wanted to surface, and that is what we trained for.
 
1
+ # DebugZero: Teaching a Coding Agent to Create and Fix Bugs
2
+
3
+ Most code benchmarks ask a model to write a fresh solution from scratch. That is useful, but it skips a big part of real programming work: debugging code that is almost correct.
4
+
5
+ That is the problem we built **DebugZero** to explore.
6
+
7
+ DebugZero is an OpenEnv environment where a coding agent learns through a two-role game:
8
+
9
+ - a **Proposer** takes a correct function and introduces a small but meaningful bug
10
+ - a **Solver** takes that buggy function and tries to repair it
11
+
12
+ The environment runs the submitted code in a sandbox, executes tests, and returns structured observations and rewards. In other words, the model does not just generate code and hope for the best. It acts inside an environment that can tell it whether a bug is real, whether a fix works, and whether the behavior is improving over time.
13
+
14
+ ## Why we built it
15
+
16
+ We wanted an environment that treats debugging as a first-class skill.
17
+
18
+ In practice, strong programmers do more than write correct code. They also:
19
+
20
+ - recognize how correct-looking code can fail
21
+ - make small, targeted edits instead of rewriting everything
22
+ - use test failures as evidence
23
+ - recover from mistakes efficiently
24
+
25
+ Static benchmarks usually measure the end result. DebugZero is meant to train the process.
26
+
27
+ ## How an episode works
28
+
29
+ Each episode starts from a clean seed task: a short Python function plus a hidden test harness.
30
+
31
+ On the first turn, the proposer submits a modified version of the function. The goal is not to destroy the program randomly. The goal is to create a bug that is realistic, small, and detectable by tests.
32
+
33
+ The environment then:
34
+
35
+ 1. parses the submitted code
36
+ 2. executes it in a sandboxed subprocess
37
+ 3. runs the task tests
38
+ 4. returns the current code, execution result, test status, reward, and next role
39
+
40
+ If the proposer successfully creates a valid bug, the solver gets the next turn. The solver then submits a repaired function, and the environment checks whether the original behavior has been restored.
41
+
42
+ This makes the whole loop executable and grounded. The agent is not rewarded for sounding plausible. It is rewarded for actually changing program behavior in the intended way.
43
+
44
+ ## What makes the reward signal useful
45
+
46
+ DebugZero uses role-aware rewards instead of a single generic success metric.
47
+
48
+ For the proposer, reward is higher when the bug is:
49
+
50
+ - syntactically valid
51
+ - actually test-breaking
52
+ - close to the original implementation rather than random corruption
53
+
54
+ For the solver, reward is higher when the fix cleanly restores the expected behavior.
55
+
56
+ That design matters because it pushes both roles toward realistic debugging behavior. The proposer learns to create useful failures. The solver learns to make precise repairs.
57
+
58
+ ## What we trained
59
+
60
+ We trained a policy for this environment using **GRPO** and role-conditioned prompting. One important design choice was to train against the **deployed environment itself**, not against notebook-local copies of the environment logic.
61
+
62
+ That means the training loop interacts with the same OpenEnv interface that serves the environment in deployment:
63
+
64
+ - reset the environment
65
+ - observe the current task state
66
+ - submit a proposer or solver action
67
+ - receive reward and updated observation
68
+
69
+ This kept training aligned with the real environment instead of drifting into a separate offline approximation.
70
+
71
+ ## Why the two-role setup is interesting
72
+
73
+ The most fun part of DebugZero is that it creates its own pressure to improve.
74
+
75
+ If the solver becomes stronger, the proposer has to invent better bugs. If the proposer becomes better at making subtle failures, the solver has to become more precise at repair. That gives us a natural self-play curriculum for debugging.
76
+
77
+ Instead of hand-authoring every training example, we get an environment where challenge and skill can rise together.
78
+
79
+ ## What DebugZero is really trying to test
80
+
81
+ At a deeper level, this project is about whether coding agents can become better debuggers through interaction rather than static supervision alone.
82
+
83
+ We care about questions like:
84
+
85
+ - Can an agent learn to create realistic failure modes?
86
+ - Can it repair bugs without over-editing the program?
87
+ - Can self-play produce a useful curriculum for code reasoning?
88
+ - Can reward grounded in execution and tests teach something that static datasets miss?
89
+
90
+ DebugZero is our attempt at turning those questions into something concrete and measurable.
91
+
92
+ ## Links
93
+
94
+ - Hugging Face Space: https://the-fool-09-debugzero.hf.space
95
+ - Hugging Face project page: https://huggingface.co/spaces/The-Fool-09/debugZero
96
+ - Training notebook: `notebooks/train_colab_updated_1.ipynb`
97
+
98
+ In short, DebugZero is not just a benchmark where a model writes code. It is an environment where the model learns from failure, creates new failure cases, and improves through the loop of breaking and repairing programs. That is the behavior we wanted to surface, and that is what we trained for.
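Reviewer sketch: the role-aware reward described in the Blog.md text above (a proposer bug should be syntactically valid, test-breaking, and close to the original; a solver fix should restore passing tests) could look roughly like the following. This is a minimal illustration under stated assumptions, not the grader in `server/graders.py`. The names `proposer_reward`, `solver_reward`, and `run_tests`, the 0.5/0.5 weighting, and the toy `add` seed task are all hypothetical, and the in-process `exec` only stands in for the sandboxed subprocess the blog mentions.

```python
import ast
import difflib

# Hypothetical seed task: a correct function plus its tests.
SEED = "def add(a, b):\n    return a + b\n"
TESTS = "assert add(2, 3) == 5\nassert add(-1, 1) == 0\n"


def is_valid_python(code: str) -> bool:
    """Return True if the code parses without a syntax error."""
    try:
        ast.parse(code)
        return True
    except SyntaxError:
        return False


def run_tests(code: str, tests: str) -> bool:
    """Execute code and tests in-process.

    Assumption: a real environment would use a sandboxed subprocess with a
    timeout; this toy version has neither isolation nor a time limit.
    """
    namespace: dict = {}
    try:
        exec(code, namespace)
        exec(tests, namespace)
        return True
    except Exception:
        return False


def proposer_reward(original: str, mutated: str, tests: str) -> float:
    """Reward a bug that parses, breaks the tests, and stays near the original."""
    if not is_valid_python(mutated):
        return 0.0  # syntax errors are not useful bugs
    if run_tests(mutated, tests):
        return 0.0  # tests still pass, so nothing was broken
    similarity = difflib.SequenceMatcher(None, original, mutated).ratio()
    return 0.5 + 0.5 * similarity  # hypothetical weighting


def solver_reward(fixed: str, tests: str) -> float:
    """Reward a fix only when the tests pass again."""
    return 1.0 if is_valid_python(fixed) and run_tests(fixed, tests) else 0.0


buggy = SEED.replace("a + b", "a - b")  # one small local change
print(proposer_reward(SEED, buggy, TESTS))
print(solver_reward(SEED, TESTS))
print(solver_reward(buggy, TESTS))
```

The similarity term is what discourages "random corruption": a full rewrite that breaks the tests still earns something, but less than a one-character mutation.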
Dockerfile CHANGED
@@ -47,6 +47,7 @@ RUN apt-get update && \
 
 COPY --from=builder /app/.venv /app/.venv
 COPY --from=builder /app/env /app/env
+COPY --from=builder /app/env/README.md /app/README.md
 
 ENV PATH="/app/.venv/bin:$PATH"
 ENV PYTHONPATH="/app/env:$PYTHONPATH"
client.py CHANGED
@@ -69,29 +69,29 @@ class DebugzeroEnv(
69
  Args:
70
  payload: JSON response data from server
71
 
72
- Returns:
73
- StepResult with DebugzeroObservation
74
- """
75
- obs_data = payload.get("observation", {})
76
- reward_value = payload.get("reward", obs_data.get("reward"))
77
- done_value = payload.get("done", obs_data.get("done", False))
78
- observation = DebugzeroObservation(
79
- role_next=obs_data.get("role_next", "proposer"),
80
- current_code=obs_data.get("current_code", ""),
81
- execution_result=obs_data.get("execution_result", ""),
82
- tests_passed=obs_data.get("tests_passed", False),
83
- syntax_error=obs_data.get("syntax_error", False),
84
- score=obs_data.get("score", 0.0),
85
- done=done_value,
86
- reward=reward_value,
87
- metadata=obs_data.get("metadata", {}),
88
- )
89
-
90
- return StepResult(
91
- observation=observation,
92
- reward=reward_value,
93
- done=done_value,
94
- )
95
 
96
  def _parse_state(self, payload: Dict) -> DebugzeroState:
97
  """
 
69
  Args:
70
  payload: JSON response data from server
71
 
72
+ Returns:
73
+ StepResult with DebugzeroObservation
74
+ """
75
+ obs_data = payload.get("observation", {})
76
+ reward_value = payload.get("reward", obs_data.get("reward"))
77
+ done_value = payload.get("done", obs_data.get("done", False))
78
+ observation = DebugzeroObservation(
79
+ role_next=obs_data.get("role_next", "proposer"),
80
+ current_code=obs_data.get("current_code", ""),
81
+ execution_result=obs_data.get("execution_result", ""),
82
+ tests_passed=obs_data.get("tests_passed", False),
83
+ syntax_error=obs_data.get("syntax_error", False),
84
+ score=obs_data.get("score", 0.0),
85
+ done=done_value,
86
+ reward=reward_value,
87
+ metadata=obs_data.get("metadata", {}),
88
+ )
89
+
90
+ return StepResult(
91
+ observation=observation,
92
+ reward=reward_value,
93
+ done=done_value,
94
+ )
95
 
96
  def _parse_state(self, payload: Dict) -> DebugzeroState:
97
  """
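Reviewer sketch: the parsing logic in the client.py hunk above prefers top-level `reward` and `done` keys and falls back to the nested observation. The snippet below isolates that fallback so its edge cases are easy to see. `Observation` and `parse_payload` are hypothetical stand-ins written for this note, not the real `DebugzeroObservation` or the client method, and only a few of the fields are reproduced.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional


@dataclass
class Observation:
    """Hypothetical, trimmed stand-in for DebugzeroObservation."""
    role_next: str = "proposer"
    tests_passed: bool = False
    reward: Optional[float] = None
    done: bool = False
    metadata: Dict[str, Any] = field(default_factory=dict)


def parse_payload(payload: Dict[str, Any]) -> Observation:
    obs_data = payload.get("observation", {})
    # Top-level keys win. The nested value is used only when the top-level
    # key is absent; a key that is present with value None is returned as-is.
    reward = payload.get("reward", obs_data.get("reward"))
    done = payload.get("done", obs_data.get("done", False))
    return Observation(
        role_next=obs_data.get("role_next", "proposer"),
        tests_passed=obs_data.get("tests_passed", False),
        reward=reward,
        done=done,
        metadata=obs_data.get("metadata", {}),
    )


nested_only = parse_payload(
    {"observation": {"reward": 0.5, "done": True, "role_next": "solver"}}
)
top_level = parse_payload({"reward": 1.0, "observation": {"reward": 0.5}})
explicit_none = parse_payload({"reward": None, "observation": {"reward": 0.5}})

print(nested_only.reward, nested_only.done, nested_only.role_next)
print(top_level.reward)
print(explicit_none.reward)
```

The last case is worth noting when reading the hunk: because `dict.get` only uses its default for a missing key, a server response with `"reward": null` at the top level yields `None` even if the nested observation carries a number.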
eval/api_baseline.py CHANGED
@@ -1,346 +1,346 @@
1
- import asyncio
2
- import inspect
3
- import json
4
- import os
5
- import sys
6
- import textwrap
7
- from typing import Any, Optional
8
-
9
- sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
10
-
11
- from dotenv import load_dotenv
12
- from openai import OpenAI
13
-
14
- from client import DebugzeroEnv
15
- from models import DebugzeroAction
16
-
17
- load_dotenv()
18
-
19
-
20
- API_BASE_URL = os.getenv("API_BASE_URL") or os.getenv("OPENAI_BASE_URL", "https://openrouter.ai/api/v1")
21
- MODEL_NAME = os.getenv("MODEL_NAME") or os.getenv("OPENAI_MODEL", "meta-llama/llama-3.1-8b-instruct")
22
- API_KEY = os.getenv("API_KEY") or os.getenv("OPENAI_API_KEY") or os.getenv("HF_TOKEN")
23
- ENV_URL = os.getenv("DEBUGZERO_ENV_URL", "http://localhost:8000")
24
-
25
- NUM_EPISODES = int(os.getenv("NUM_EPISODES", "6"))
26
- MAX_STEPS = int(os.getenv("MAX_STEPS", "8"))
27
- PROPOSER_TEMPERATURE = float(os.getenv("PROPOSER_TEMPERATURE", "0.7"))
28
- SOLVER_TEMPERATURE = float(os.getenv("SOLVER_TEMPERATURE", "0.2"))
29
- MAX_TOKENS = int(os.getenv("MAX_TOKENS", "1024"))
30
- BUG_FOCUS = os.getenv("DEBUGZERO_BUG_FOCUS")
31
-
32
-
33
- def extract_python_code(text: str) -> str:
34
- content = (text or "").strip()
35
- if content.startswith("```"):
36
- content = content.split("\n", 1)[-1]
37
- if content.endswith("```"):
38
- content = content.rsplit("\n", 1)[0]
39
- return content.strip()
40
-
41
-
42
- def summarize_error(text: str, max_chars: int = 220) -> str:
43
- cleaned = " ".join(text.strip().split())
44
- if not cleaned:
45
- return "null"
46
- if len(cleaned) <= max_chars:
47
- return cleaned
48
- return cleaned[: max_chars - 3].rstrip() + "..."
49
-
50
-
51
- def extract_env_error(result: Any) -> Optional[str]:
52
- for attr in ("last_action_error", "error", "message"):
53
- if hasattr(result, attr):
54
- value = getattr(result, attr)
55
- if value:
56
- return str(value)
57
-
58
- obs = getattr(result, "observation", None)
59
- if obs is None:
60
- return None
61
-
62
- for attr in ("last_action_error", "error"):
63
- if hasattr(obs, attr):
64
- value = getattr(obs, attr)
65
- if value:
66
- return str(value)
67
-
68
- execution_result = getattr(obs, "execution_result", "")
69
- if isinstance(execution_result, str) and execution_result:
70
- if getattr(obs, "syntax_error", False):
71
- return summarize_error(execution_result)
72
- if execution_result.startswith("Unsafe import detected."):
73
- return execution_result
74
- if not getattr(obs, "tests_passed", False):
75
- return summarize_error(execution_result)
76
-
77
- return None
78
-
79
-
80
- def compact_action_string(role: str, code: str) -> str:
81
- return json.dumps({"role": role, "code": code}, separators=(",", ":"), ensure_ascii=False)
82
-
83
-
84
- def build_prompt(obs_dict: dict[str, Any], history: list[str]) -> str:
85
- role = str(obs_dict.get("role_next", "proposer"))
86
- current_code = str(obs_dict.get("current_code", ""))
87
- execution_result = str(obs_dict.get("execution_result", ""))
88
- tests_passed = bool(obs_dict.get("tests_passed", False))
89
- syntax_error = bool(obs_dict.get("syntax_error", False))
90
- metadata = obs_dict.get("metadata", {}) or {}
91
- seed_id = metadata.get("seed_id", "unknown")
92
- history_block = "\n".join(history[-4:]) if history else "None"
93
-
94
- if role == "proposer":
95
- focus_line = ""
96
- if BUG_FOCUS:
97
- focus_line = f"- Focus specifically on the `{BUG_FOCUS}` mutation family.\n"
98
- instructions = textwrap.dedent(
99
- f"""
100
- You are the Proposer in a debugging self-play environment.
101
- Return a full Python function with exactly one small logical bug injected.
102
-
103
- Rules:
104
- - Keep the code valid Python.
105
- - Keep the same function signature.
106
- - Preserve the overall structure and formatting as much as possible.
107
- - Make exactly one small local behavioral change.
108
- - Avoid comments, explanations, markdown outside the code block, and broad rewrites.
109
- {focus_line}- Your goal is to make tests fail without creating a syntax error.
110
- """
111
- ).strip()
112
- else:
113
- instructions = textwrap.dedent(
114
- """
115
- You are the Solver in a debugging self-play environment.
116
- Return the full fixed Python function.
117
-
118
- Rules:
119
- - Keep the code valid Python.
120
- - Keep the same function signature.
121
- - Make the smallest correct local fix you can.
122
- - Use the failure output to guide the repair.
123
- - Avoid comments, explanations, markdown outside the code block, and unrelated refactors.
124
- """
125
- ).strip()
126
-
127
- return textwrap.dedent(
128
- f"""
129
- {instructions}
130
-
131
- Current environment state:
132
- - seed_id: {seed_id}
133
- - role_next: {role}
134
- - tests_passed: {tests_passed}
135
- - syntax_error: {syntax_error}
136
-
137
- Current code:
138
- ```python
139
- {current_code}
140
- ```
141
-
142
- Execution result:
143
- {execution_result if execution_result else "None"}
144
-
145
- Previous actions:
146
- {history_block}
147
-
148
- Return only the full Python code inside triple backticks.
149
- """
150
- ).strip()
151
-
152
-
153
- def get_model_code(client: OpenAI, obs_dict: dict[str, Any], history: list[str]) -> str:
154
- role = str(obs_dict.get("role_next", "proposer"))
155
- prompt = build_prompt(obs_dict, history)
156
- temperature = PROPOSER_TEMPERATURE if role == "proposer" else SOLVER_TEMPERATURE
157
-
158
- response = client.chat.completions.create(
159
- model=MODEL_NAME,
160
- messages=[
161
- {"role": "system", "content": "You are an expert Python coder."},
162
- {"role": "user", "content": prompt},
163
- ],
164
- temperature=temperature,
165
- max_tokens=MAX_TOKENS,
166
- )
167
-
168
- return extract_python_code(response.choices[0].message.content or "")
169
-
170
-
171
- async def maybe_await(value: Any) -> Any:
172
- if inspect.isawaitable(value):
173
- return await value
174
- return value
175
-
176
-
177
- async def call_env_method(obj: Any, method_name: str, *args: Any) -> Any:
178
- method = getattr(obj, method_name)
179
- result = method(*args)
180
- return await maybe_await(result)
181
-
182
-
183
- async def make_env() -> Any:
184
- max_retries = 30
185
- for attempt in range(max_retries):
186
- try:
187
- return DebugzeroEnv(base_url=ENV_URL)
188
- except Exception as exc:
189
- print(
190
- f"[SYSTEM ERROR] Env connection to {ENV_URL} failed (attempt {attempt + 1}/{max_retries}): {exc}",
191
- file=sys.stderr,
192
- flush=True,
193
- )
194
- if attempt < max_retries - 1:
195
- await asyncio.sleep(5.0)
196
- else:
197
- raise
198
-
199
-
200
- def print_live_summary(metrics: dict[str, Any]) -> None:
201
- episodes = max(1, int(metrics["episodes"]))
202
- proposer_attempts = max(1, int(metrics["proposer_attempts"]))
203
- solver_attempts = max(1, int(metrics["solver_attempts"]))
204
- rewards = metrics["rewards"]
205
- average_reward = (sum(rewards) / len(rewards)) if rewards else 0.0
206
-
207
- print("\n" + "=" * 80)
208
- print("Live API summary")
209
- print("=" * 80)
210
- print(f"Episode success rate: {metrics['episode_successes'] / episodes:.2%}")
211
- print(f"Proposer syntax rate: {metrics['proposer_syntax_errors'] / proposer_attempts:.2%}")
212
- print(f"Solver syntax rate: {metrics['solver_syntax_errors'] / solver_attempts:.2%}")
213
- print(f"Average step reward: {average_reward:.2f}")
214
- print(f"Average steps/episode: {metrics['total_steps'] / episodes:.2f}")
215
- print(f"Representative success: {metrics['representative_success']}")
216
- print(f"Representative failure: {metrics['representative_failure']}")
217
-
218
-
219
- async def run_live_api_probe() -> dict[str, Any] | None:
220
- if not API_KEY:
221
- print("Skipping live API probe: OPENAI_API_KEY/API_KEY is not set.")
222
- return None
223
- if not MODEL_NAME:
224
- print("Skipping live API probe: OPENAI_MODEL/MODEL_NAME is not set.")
225
- return None
226
-
227
- client = OpenAI(api_key=API_KEY, base_url=API_BASE_URL)
228
- env = await make_env()
229
-
230
- metrics = {
231
- "episodes": NUM_EPISODES,
232
- "episode_successes": 0,
233
- "proposer_attempts": 0,
234
- "solver_attempts": 0,
235
- "proposer_syntax_errors": 0,
236
- "solver_syntax_errors": 0,
237
- "rewards": [],
238
- "total_steps": 0,
239
- "representative_success": None,
240
- "representative_failure": None,
241
- }
242
-
243
- print("=" * 80)
244
- print("Live API probe")
245
- print("=" * 80)
246
- print(f"API base URL: {API_BASE_URL}")
247
- print(f"Model: {MODEL_NAME}")
248
- print(f"Env URL: {ENV_URL}")
249
-
250
- try:
251
- for episode in range(1, NUM_EPISODES + 1):
252
- result = await call_env_method(env, "reset")
253
- obs = getattr(result, "observation", None)
254
- done = bool(getattr(result, "done", False))
255
- history: list[str] = []
256
- success = False
257
-
258
- seed_id = "unknown"
259
- if obs is not None:
260
- metadata = getattr(obs, "metadata", {}) or {}
261
- seed_id = metadata.get("seed_id", "unknown")
262
-
263
- print(f"\nEpisode {episode}/{NUM_EPISODES} | seed={seed_id}")
264
-
265
- for step in range(1, MAX_STEPS + 1):
266
- if done or obs is None:
267
- break
268
-
269
- obs_dict = obs.model_dump() if hasattr(obs, "model_dump") else obs.dict()
270
- role = str(obs_dict.get("role_next", "proposer"))
271
- if role == "proposer":
272
- metrics["proposer_attempts"] += 1
273
- else:
274
- metrics["solver_attempts"] += 1
275
-
276
- try:
277
- code = await asyncio.to_thread(get_model_code, client, obs_dict, history)
278
- except Exception as exc:
279
- print(f"[SYSTEM ERROR] Model generation failed: {exc}", file=sys.stderr, flush=True)
280
- code = str(obs_dict.get("current_code", ""))
281
-
282
- action = DebugzeroAction(role=role, code=code)
283
- action_str = compact_action_string(role, code)
284
- result = await call_env_method(env, "step", action)
285
- obs = getattr(result, "observation", None)
286
- done = bool(getattr(result, "done", False))
287
- reward = float(getattr(result, "reward", 0.0) or 0.0)
288
- error = extract_env_error(result)
289
-
290
- metrics["rewards"].append(reward)
291
- metrics["total_steps"] += 1
292
-
293
- if obs is not None and getattr(obs, "syntax_error", False):
294
- if role == "proposer":
295
- metrics["proposer_syntax_errors"] += 1
296
- else:
297
- metrics["solver_syntax_errors"] += 1
298
-
299
- print(
300
- f" step={step} role={role} reward={reward:.2f} done={str(done).lower()} error={error or 'null'}",
301
- flush=True,
302
- )
303
- history.append(f"Step {step}: {action_str} -> reward {reward:.2f}")
304
-
305
- if done and obs is not None:
306
- success = bool(getattr(obs, "tests_passed", False)) and not bool(
307
- getattr(obs, "syntax_error", False)
308
- )
309
- if success:
310
- metrics["episode_successes"] += 1
311
- if metrics["representative_success"] is None:
312
- metrics["representative_success"] = {
313
- "seed_id": getattr(obs, "metadata", {}).get("seed_id", "unknown"),
314
- "steps": step,
315
- "reward": reward,
316
- }
317
- elif metrics["representative_failure"] is None:
318
- metrics["representative_failure"] = {
319
- "seed_id": getattr(obs, "metadata", {}).get("seed_id", "unknown"),
320
- "steps": step,
321
- "execution_result": getattr(obs, "execution_result", ""),
322
- }
323
- break
324
-
325
- if not success and metrics["representative_failure"] is None:
326
- failure_seed = seed_id
327
- failure_output = getattr(obs, "execution_result", "") if obs is not None else ""
328
- metrics["representative_failure"] = {
329
- "seed_id": failure_seed,
330
- "steps": min(MAX_STEPS, len(history)),
331
- "execution_result": failure_output,
332
- }
333
-
334
- return metrics
335
- finally:
336
- await call_env_method(env, "close")
337
-
338
-
339
- async def main() -> None:
340
- metrics = await run_live_api_probe()
341
- if metrics is not None:
342
- print_live_summary(metrics)
343
-
344
-
345
- if __name__ == "__main__":
346
- asyncio.run(main())
 
1
+ import asyncio
2
+ import inspect
3
+ import json
4
+ import os
5
+ import sys
6
+ import textwrap
7
+ from typing import Any, Optional
8
+
9
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
10
+
11
+ from dotenv import load_dotenv
12
+ from openai import OpenAI
13
+
14
+ from client import DebugzeroEnv
15
+ from models import DebugzeroAction
16
+
17
+ load_dotenv()
18
+
19
+
20
+ API_BASE_URL = os.getenv("API_BASE_URL") or os.getenv("OPENAI_BASE_URL", "https://openrouter.ai/api/v1")
21
+ MODEL_NAME = os.getenv("MODEL_NAME") or os.getenv("OPENAI_MODEL", "meta-llama/llama-3.1-8b-instruct")
22
+ API_KEY = os.getenv("API_KEY") or os.getenv("OPENAI_API_KEY") or os.getenv("HF_TOKEN")
23
+ ENV_URL = os.getenv("DEBUGZERO_ENV_URL", "http://localhost:8000")
24
+
25
+ NUM_EPISODES = int(os.getenv("NUM_EPISODES", "6"))
26
+ MAX_STEPS = int(os.getenv("MAX_STEPS", "8"))
27
+ PROPOSER_TEMPERATURE = float(os.getenv("PROPOSER_TEMPERATURE", "0.7"))
28
+ SOLVER_TEMPERATURE = float(os.getenv("SOLVER_TEMPERATURE", "0.2"))
29
+ MAX_TOKENS = int(os.getenv("MAX_TOKENS", "1024"))
30
+ BUG_FOCUS = os.getenv("DEBUGZERO_BUG_FOCUS")
31
+
32
+
33
+ def extract_python_code(text: str) -> str:
34
+ content = (text or "").strip()
35
+ if content.startswith("```"):
36
+ content = content.split("\n", 1)[-1]
37
+ if content.endswith("```"):
38
+ content = content.rsplit("\n", 1)[0]
39
+ return content.strip()
40
+
41
+
42
+ def summarize_error(text: str, max_chars: int = 220) -> str:
43
+ cleaned = " ".join(text.strip().split())
44
+ if not cleaned:
45
+ return "null"
46
+ if len(cleaned) <= max_chars:
47
+ return cleaned
48
+ return cleaned[: max_chars - 3].rstrip() + "..."
49
+
50
+
51
+ def extract_env_error(result: Any) -> Optional[str]:
52
+ for attr in ("last_action_error", "error", "message"):
53
+ if hasattr(result, attr):
54
+ value = getattr(result, attr)
55
+ if value:
56
+ return str(value)
57
+
58
+ obs = getattr(result, "observation", None)
59
+ if obs is None:
60
+ return None
61
+
62
+ for attr in ("last_action_error", "error"):
63
+ if hasattr(obs, attr):
64
+ value = getattr(obs, attr)
65
+ if value:
66
+ return str(value)
67
+
68
+ execution_result = getattr(obs, "execution_result", "")
69
+ if isinstance(execution_result, str) and execution_result:
70
+ if getattr(obs, "syntax_error", False):
71
+ return summarize_error(execution_result)
72
+ if execution_result.startswith("Unsafe import detected."):
73
+ return execution_result
74
+ if not getattr(obs, "tests_passed", False):
75
+ return summarize_error(execution_result)
76
+
77
+ return None
78
+
79
+
80
+ def compact_action_string(role: str, code: str) -> str:
81
+ return json.dumps({"role": role, "code": code}, separators=(",", ":"), ensure_ascii=False)
82
+
83
+
84
+ def build_prompt(obs_dict: dict[str, Any], history: list[str]) -> str:
85
+ role = str(obs_dict.get("role_next", "proposer"))
86
+ current_code = str(obs_dict.get("current_code", ""))
87
+ execution_result = str(obs_dict.get("execution_result", ""))
88
+ tests_passed = bool(obs_dict.get("tests_passed", False))
89
+ syntax_error = bool(obs_dict.get("syntax_error", False))
90
+ metadata = obs_dict.get("metadata", {}) or {}
91
+ seed_id = metadata.get("seed_id", "unknown")
92
+ history_block = "\n".join(history[-4:]) if history else "None"
93
+
94
+ if role == "proposer":
95
+ focus_line = ""
96
+ if BUG_FOCUS:
97
+ focus_line = f"- Focus specifically on the `{BUG_FOCUS}` mutation family.\n"
98
+ instructions = textwrap.dedent(
99
+ f"""
100
+ You are the Proposer in a debugging self-play environment.
101
+ Return a full Python function with exactly one small logical bug injected.
102
+
103
+ Rules:
104
+ - Keep the code valid Python.
105
+ - Keep the same function signature.
106
+ - Preserve the overall structure and formatting as much as possible.
107
+ - Make exactly one small local behavioral change.
108
+ - Avoid comments, explanations, markdown outside the code block, and broad rewrites.
109
+ {focus_line}- Your goal is to make tests fail without creating a syntax error.
110
+ """
111
+ ).strip()
112
+ else:
113
+ instructions = textwrap.dedent(
114
+ """
115
+ You are the Solver in a debugging self-play environment.
116
+ Return the full fixed Python function.
117
+
118
+ Rules:
119
+ - Keep the code valid Python.
120
+ - Keep the same function signature.
121
+ - Make the smallest correct local fix you can.
122
+ - Use the failure output to guide the repair.
123
+ - Avoid comments, explanations, markdown outside the code block, and unrelated refactors.
124
+ """
125
+ ).strip()
126
+
127
+ return textwrap.dedent(
128
+ f"""
129
+ {instructions}
130
+
131
+ Current environment state:
132
+ - seed_id: {seed_id}
133
+ - role_next: {role}
134
+ - tests_passed: {tests_passed}
135
+ - syntax_error: {syntax_error}
136
+
137
+ Current code:
138
+ ```python
139
+ {current_code}
140
+ ```
141
+
142
+ Execution result:
143
+ {execution_result if execution_result else "None"}
144
+
145
+ Previous actions:
146
+ {history_block}
147
+
148
+ Return only the full Python code inside triple backticks.
149
+ """
150
+ ).strip()
151
+
152
+
153
+ def get_model_code(client: OpenAI, obs_dict: dict[str, Any], history: list[str]) -> str:
154
+ role = str(obs_dict.get("role_next", "proposer"))
155
+ prompt = build_prompt(obs_dict, history)
156
+ temperature = PROPOSER_TEMPERATURE if role == "proposer" else SOLVER_TEMPERATURE
157
+
158
+ response = client.chat.completions.create(
159
+ model=MODEL_NAME,
160
+ messages=[
161
+ {"role": "system", "content": "You are an expert Python coder."},
162
+ {"role": "user", "content": prompt},
163
+ ],
164
+ temperature=temperature,
165
+ max_tokens=MAX_TOKENS,
166
+ )
167
+
168
+ return extract_python_code(response.choices[0].message.content or "")
169
+
170
+
171
+ async def maybe_await(value: Any) -> Any:
172
+ if inspect.isawaitable(value):
173
+ return await value
174
+ return value
175
+
176
+
177
+ async def call_env_method(obj: Any, method_name: str, *args: Any) -> Any:
178
+ method = getattr(obj, method_name)
179
+ result = method(*args)
180
+ return await maybe_await(result)
181
+
182
+
183
+ async def make_env() -> Any:
184
+ max_retries = 30
185
+ for attempt in range(max_retries):
186
+ try:
187
+ return DebugzeroEnv(base_url=ENV_URL)
188
+ except Exception as exc:
189
+ print(
190
+ f"[SYSTEM ERROR] Env connection to {ENV_URL} failed (attempt {attempt + 1}/{max_retries}): {exc}",
191
+ file=sys.stderr,
192
+ flush=True,
193
+ )
194
+ if attempt < max_retries - 1:
195
+ await asyncio.sleep(5.0)
196
+ else:
197
+ raise
198
+
199
+
200
+ def print_live_summary(metrics: dict[str, Any]) -> None:
201
+ episodes = max(1, int(metrics["episodes"]))
202
+ proposer_attempts = max(1, int(metrics["proposer_attempts"]))
203
+ solver_attempts = max(1, int(metrics["solver_attempts"]))
204
+ rewards = metrics["rewards"]
205
+ average_reward = (sum(rewards) / len(rewards)) if rewards else 0.0
206
+
207
+ print("\n" + "=" * 80)
208
+ print("Live API summary")
209
+ print("=" * 80)
210
+ print(f"Episode success rate: {metrics['episode_successes'] / episodes:.2%}")
211
+ print(f"Proposer syntax rate: {metrics['proposer_syntax_errors'] / proposer_attempts:.2%}")
212
+ print(f"Solver syntax rate: {metrics['solver_syntax_errors'] / solver_attempts:.2%}")
213
+ print(f"Average step reward: {average_reward:.2f}")
214
+ print(f"Average steps/episode: {metrics['total_steps'] / episodes:.2f}")
215
+ print(f"Representative success: {metrics['representative_success']}")
216
+ print(f"Representative failure: {metrics['representative_failure']}")
217
+
218
+
219
+ async def run_live_api_probe() -> dict[str, Any] | None:
220
+ if not API_KEY:
221
+ print("Skipping live API probe: OPENAI_API_KEY/API_KEY is not set.")
222
+ return None
223
+ if not MODEL_NAME:
224
+ print("Skipping live API probe: OPENAI_MODEL/MODEL_NAME is not set.")
225
+ return None
226
+
227
+ client = OpenAI(api_key=API_KEY, base_url=API_BASE_URL)
228
+ env = await make_env()
229
+
230
+ metrics = {
231
+ "episodes": NUM_EPISODES,
232
+ "episode_successes": 0,
233
+ "proposer_attempts": 0,
234
+ "solver_attempts": 0,
235
+ "proposer_syntax_errors": 0,
236
+ "solver_syntax_errors": 0,
237
+ "rewards": [],
238
+ "total_steps": 0,
239
+ "representative_success": None,
240
+ "representative_failure": None,
241
+ }
242
+
243
+ print("=" * 80)
244
+ print("Live API probe")
245
+ print("=" * 80)
246
+ print(f"API base URL: {API_BASE_URL}")
247
+ print(f"Model: {MODEL_NAME}")
248
+ print(f"Env URL: {ENV_URL}")
249
+
250
+ try:
251
+ for episode in range(1, NUM_EPISODES + 1):
252
+ result = await call_env_method(env, "reset")
253
+ obs = getattr(result, "observation", None)
254
+ done = bool(getattr(result, "done", False))
255
+ history: list[str] = []
256
+ success = False
257
+
258
+ seed_id = "unknown"
259
+ if obs is not None:
260
+ metadata = getattr(obs, "metadata", {}) or {}
261
+ seed_id = metadata.get("seed_id", "unknown")
262
+
263
+ print(f"\nEpisode {episode}/{NUM_EPISODES} | seed={seed_id}")
264
+
265
+ for step in range(1, MAX_STEPS + 1):
266
+ if done or obs is None:
267
+ break
268
+
269
+ obs_dict = obs.model_dump() if hasattr(obs, "model_dump") else obs.dict()
270
+ role = str(obs_dict.get("role_next", "proposer"))
271
+ if role == "proposer":
272
+ metrics["proposer_attempts"] += 1
273
+ else:
274
+ metrics["solver_attempts"] += 1
275
+
276
+ try:
277
+ code = await asyncio.to_thread(get_model_code, client, obs_dict, history)
278
+ except Exception as exc:
279
+ print(f"[SYSTEM ERROR] Model generation failed: {exc}", file=sys.stderr, flush=True)
280
+ code = str(obs_dict.get("current_code", ""))
281
+
282
+ action = DebugzeroAction(role=role, code=code)
283
+ action_str = compact_action_string(role, code)
284
+ result = await call_env_method(env, "step", action)
285
+ obs = getattr(result, "observation", None)
286
+ done = bool(getattr(result, "done", False))
287
+ reward = float(getattr(result, "reward", 0.0) or 0.0)
288
+ error = extract_env_error(result)
289
+
290
+ metrics["rewards"].append(reward)
291
+ metrics["total_steps"] += 1
292
+
293
+ if obs is not None and getattr(obs, "syntax_error", False):
294
+ if role == "proposer":
295
+ metrics["proposer_syntax_errors"] += 1
296
+ else:
297
+ metrics["solver_syntax_errors"] += 1
298
+
299
+ print(
300
+ f" step={step} role={role} reward={reward:.2f} done={str(done).lower()} error={error or 'null'}",
301
+ flush=True,
302
+ )
303
+ history.append(f"Step {step}: {action_str} -> reward {reward:.2f}")
304
+
305
+ if done and obs is not None:
306
+ success = bool(getattr(obs, "tests_passed", False)) and not bool(
307
+ getattr(obs, "syntax_error", False)
308
+ )
309
+ if success:
310
+ metrics["episode_successes"] += 1
311
+ if metrics["representative_success"] is None:
312
+ metrics["representative_success"] = {
313
+ "seed_id": getattr(obs, "metadata", {}).get("seed_id", "unknown"),
314
+ "steps": step,
315
+ "reward": reward,
316
+ }
317
+ elif metrics["representative_failure"] is None:
318
+ metrics["representative_failure"] = {
319
+ "seed_id": getattr(obs, "metadata", {}).get("seed_id", "unknown"),
320
+ "steps": step,
321
+ "execution_result": getattr(obs, "execution_result", ""),
322
+ }
323
+ break
324
+
325
+ if not success and metrics["representative_failure"] is None:
326
+ failure_seed = seed_id
327
+ failure_output = getattr(obs, "execution_result", "") if obs is not None else ""
328
+ metrics["representative_failure"] = {
329
+ "seed_id": failure_seed,
330
+ "steps": min(MAX_STEPS, len(history)),
331
+ "execution_result": failure_output,
332
+ }
333
+
334
+ return metrics
335
+ finally:
336
+ await call_env_method(env, "close")
337
+
338
+
339
+ async def main() -> None:
340
+ metrics = await run_live_api_probe()
341
+ if metrics is not None:
342
+ print_live_summary(metrics)
343
+
344
+
345
+ if __name__ == "__main__":
346
+ asyncio.run(main())
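Note on `maybe_await` / `call_env_method` in the file above: the pair lets the same driver loop talk to an environment client whose `reset`, `step`, and `close` methods may be either plain functions or coroutines. The sketch below reproduces the two helpers and exercises them against two stand-in classes. `SyncEnv`, `AsyncEnv`, and `demo` are hypothetical names used only for illustration and are not part of this commit.

```python
import asyncio
import inspect
from typing import Any


async def maybe_await(value: Any) -> Any:
    # Await coroutines and other awaitables; pass plain values through unchanged.
    if inspect.isawaitable(value):
        return await value
    return value


async def call_env_method(obj: Any, method_name: str, *args: Any) -> Any:
    # Look the method up by name, call it, and normalise the result.
    method = getattr(obj, method_name)
    return await maybe_await(method(*args))


class SyncEnv:
    """Hypothetical client with a blocking reset()."""

    def reset(self) -> str:
        return "sync-reset"


class AsyncEnv:
    """Hypothetical client with a coroutine reset()."""

    async def reset(self) -> str:
        return "async-reset"


async def demo() -> list:
    # Both clients go through the same call site.
    return [
        await call_env_method(SyncEnv(), "reset"),
        await call_env_method(AsyncEnv(), "reset"),
    ]
```

Calling `asyncio.run(demo())` drives both stand-ins through one code path, which is presumably why the scripts in this commit never check whether the client is synchronous.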
inference.py CHANGED
@@ -1,332 +1,332 @@
1
- import asyncio
2
- import inspect
3
- import json
4
- import os
5
- import sys
6
- import textwrap
7
- from typing import Any, List, Optional
8
-
9
- from dotenv import load_dotenv
10
- from openai import OpenAI
11
-
12
- from client import DebugzeroEnv
13
- from models import DebugzeroAction
14
-
15
- load_dotenv()
16
-
17
-
18
- API_BASE_URL = os.getenv("API_BASE_URL") or os.getenv("OPENAI_BASE_URL", "https://openrouter.ai/api/v1")
19
- MODEL_NAME = os.getenv("MODEL_NAME") or os.getenv("OPENAI_MODEL", "meta-llama/llama-3.1-8b-instruct")
20
- API_KEY = os.getenv("API_KEY") or os.getenv("OPENAI_API_KEY") or os.getenv("HF_TOKEN")
21
-
22
- LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
23
- ENV_URL = os.getenv("DEBUGZERO_ENV_URL", "http://localhost:8000")
24
- TASK_NAME = os.getenv("DEBUGZERO_TASK", "debugging-self-play")
25
- BENCHMARK = os.getenv("DEBUGZERO_BENCHMARK", "debugzero")
26
- BUG_FOCUS = os.getenv("DEBUGZERO_BUG_FOCUS")
27
-
28
- NUM_EPISODES = int(os.getenv("NUM_EPISODES", "3"))
29
- MAX_STEPS = int(os.getenv("MAX_STEPS", "8"))
30
- PROPOSER_TEMPERATURE = float(os.getenv("PROPOSER_TEMPERATURE", "0.7"))
31
- SOLVER_TEMPERATURE = float(os.getenv("SOLVER_TEMPERATURE", "0.2"))
32
- MAX_TOKENS = int(os.getenv("MAX_TOKENS", "1024"))
33
-
34
-
35
- def extract_python_code(text: str) -> str:
36
- content = (text or "").strip()
37
- if content.startswith("```"):
38
- content = content.split("\n", 1)[-1]
39
- if content.endswith("```"):
40
- content = content.rsplit("\n", 1)[0]
41
- return content.strip()
42
-
43
-
44
- def compact_action_string(role: str, code: str) -> str:
45
- obj = {"role": role, "code": code}
46
- return json.dumps(obj, separators=(",", ":"), ensure_ascii=False)
47
-
48
-
49
- def log_start(task: str, env: str, model: str) -> None:
50
- print(f"[START] task={task} env={env} model={model}", flush=True)
51
-
52
-
53
- def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
54
- error_val = error if error is not None else "null"
55
- action_str = action.replace("\n", "\\n")
56
- print(
57
- f"[STEP] step={step} action={action_str} reward={reward:.2f} done={str(done).lower()} error={error_val}",
58
- flush=True,
59
- )
60
-
61
-
62
- def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
63
- rewards_str = ",".join(f"{r:.2f}" for r in rewards)
64
- print(
65
- f"[END] success={str(success).lower()} steps={steps} score={score:.4f} rewards={rewards_str}",
66
- flush=True,
67
- )
68
-
69
-
70
- def summarize_error(text: str, max_chars: int = 220) -> str:
71
- cleaned = " ".join(text.strip().split())
72
- if not cleaned:
73
- return "null"
74
- if len(cleaned) <= max_chars:
75
- return cleaned
76
- return cleaned[: max_chars - 3].rstrip() + "..."
77
-
78
-
79
- def extract_env_error(result: Any) -> Optional[str]:
80
- for attr in ("last_action_error", "error", "message"):
81
- if hasattr(result, attr):
82
- value = getattr(result, attr)
83
- if value:
84
- return str(value)
85
-
86
- obs = getattr(result, "observation", None)
87
- if obs is None:
88
- return None
89
-
90
- for attr in ("last_action_error", "error"):
91
- if hasattr(obs, attr):
92
- value = getattr(obs, attr)
93
- if value:
94
- return str(value)
95
-
96
- execution_result = getattr(obs, "execution_result", "")
97
- if isinstance(execution_result, str) and execution_result:
98
- if getattr(obs, "syntax_error", False):
99
- return summarize_error(execution_result)
100
- if execution_result.startswith("Unsafe import detected."):
101
- return execution_result
102
- if not getattr(obs, "tests_passed", False):
103
- return summarize_error(execution_result)
104
-
105
- return None
106
-
107
-
108
- def build_prompt(obs_dict: dict[str, Any], history: List[str]) -> str:
109
- role = str(obs_dict.get("role_next", "proposer"))
110
- current_code = str(obs_dict.get("current_code", ""))
111
- execution_result = str(obs_dict.get("execution_result", ""))
112
- tests_passed = bool(obs_dict.get("tests_passed", False))
113
- syntax_error = bool(obs_dict.get("syntax_error", False))
114
- metadata = obs_dict.get("metadata", {}) or {}
115
- seed_id = metadata.get("seed_id", "unknown")
116
- history_block = "\n".join(history[-4:]) if history else "None"
117
-
118
- if role == "proposer":
119
- focus_line = ""
120
- if BUG_FOCUS:
121
- focus_line = f"- Focus specifically on the `{BUG_FOCUS}` mutation family.\n"
122
- task_block = textwrap.dedent(
123
- f"""
124
- You are the Proposer in a debugging self-play environment.
125
- Return a full Python function with exactly one small logical bug injected.
126
-
127
- Rules:
128
- - Keep the code valid Python.
129
- - Keep the same function signature.
130
- - Preserve the overall structure and formatting as much as possible.
131
- - Make exactly one small local behavioral change.
132
- - Avoid comments, explanations, markdown outside the code block, and broad rewrites.
133
- {focus_line}- Your goal is to make tests fail without creating a syntax error.
134
- """
135
- ).strip()
136
- else:
137
- task_block = textwrap.dedent(
138
- """
139
- You are the Solver in a debugging self-play environment.
140
- Return the full fixed Python function.
141
-
142
- Rules:
143
- - Keep the code valid Python.
144
- - Keep the same function signature.
145
- - Make the smallest correct local fix you can.
146
- - Use the failure output to guide the repair.
147
- - Avoid comments, explanations, markdown outside the code block, and unrelated refactors.
148
- """
149
- ).strip()
150
-
151
- return textwrap.dedent(
152
- f"""
153
- {task_block}
154
-
155
- Current environment state:
156
- - seed_id: {seed_id}
157
- - role_next: {role}
158
- - tests_passed: {tests_passed}
159
- - syntax_error: {syntax_error}
160
-
161
- Current code:
162
- ```python
163
- {current_code}
164
- ```
165
-
166
- Execution result:
167
- {execution_result if execution_result else "None"}
168
-
169
- Previous actions:
170
- {history_block}
171
-
172
- Return only the full Python code inside triple backticks.
173
- """
174
- ).strip()
175
-
176
-
177
- def get_model_code(client: OpenAI, obs_dict: dict[str, Any], history: List[str]) -> str:
178
- role = str(obs_dict.get("role_next", "proposer"))
179
- prompt = build_prompt(obs_dict, history)
180
- temperature = PROPOSER_TEMPERATURE if role == "proposer" else SOLVER_TEMPERATURE
181
-
182
- response = client.chat.completions.create(
183
- model=MODEL_NAME,
184
- messages=[
185
- {"role": "system", "content": "You are an expert Python coder."},
186
- {"role": "user", "content": prompt},
187
- ],
188
- temperature=temperature,
189
- max_tokens=MAX_TOKENS,
190
- )
191
-
192
- return extract_python_code(response.choices[0].message.content or "")
193
-
194
-
195
- async def maybe_await(value: Any) -> Any:
196
- if inspect.isawaitable(value):
197
- return await value
198
- return value
199
-
200
-
201
- async def call_env_method(obj: Any, method_name: str, *args: Any) -> Any:
202
- method = getattr(obj, method_name)
203
- result = method(*args)
204
- return await maybe_await(result)
205
-
206
-
207
- async def make_env() -> Any:
208
- max_retries = 30
209
-
210
- if LOCAL_IMAGE_NAME:
211
- for attempt in range(max_retries):
212
- try:
213
- env = DebugzeroEnv.from_docker_image(LOCAL_IMAGE_NAME)
214
- return await maybe_await(env)
215
- except Exception as exc:
216
- print(
217
- f"[SYSTEM ERROR] Failed to start Docker environment (attempt {attempt + 1}/{max_retries}): {exc}",
218
- file=sys.stderr,
219
- flush=True,
220
- )
221
- if attempt < max_retries - 1:
222
- await asyncio.sleep(5.0)
223
- else:
224
- raise
225
-
226
- for attempt in range(max_retries):
227
- try:
228
- return DebugzeroEnv(base_url=ENV_URL)
229
- except Exception as exc:
230
- print(
231
- f"[SYSTEM ERROR] Env connection to {ENV_URL} failed (attempt {attempt + 1}/{max_retries}): {exc}",
232
- file=sys.stderr,
233
- flush=True,
234
- )
235
- if attempt < max_retries - 1:
236
- await asyncio.sleep(5.0)
237
- else:
238
- raise
239
-
240
-
241
- async def main() -> None:
242
- if not API_KEY:
243
- print("[SYSTEM ERROR] Missing API key. Set API_KEY, OPENAI_API_KEY, or HF_TOKEN.", file=sys.stderr, flush=True)
244
- return
245
-
246
- client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
247
- env = None
248
-
249
- try:
250
- env = await make_env()
251
-
252
- for _episode in range(1, NUM_EPISODES + 1):
253
- history: List[str] = []
254
- rewards: List[float] = []
255
- steps_taken = 0
256
- score = 0.0
257
- success = False
258
-
259
- log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
260
-
261
- try:
262
- result = await call_env_method(env, "reset")
263
- done = bool(getattr(result, "done", False))
264
- obs = getattr(result, "observation", None)
265
-
266
- for step in range(1, MAX_STEPS + 1):
267
- if done or obs is None:
268
- break
269
-
270
- obs_dict = obs.model_dump() if hasattr(obs, "model_dump") else obs.dict()
271
- role = str(obs_dict.get("role_next", "proposer"))
272
-
273
- try:
274
- code = await asyncio.to_thread(get_model_code, client, obs_dict, history)
275
- env_action = DebugzeroAction(role=role, code=code)
276
- action_str = compact_action_string(role, code)
277
- except Exception as exc:
278
- print(f"[SYSTEM ERROR] Model generation failed: {exc}", file=sys.stderr, flush=True)
279
- code = obs_dict.get("current_code", "")
280
- env_action = DebugzeroAction(role=role, code=code)
281
- action_str = compact_action_string(role, code)
282
-
283
- result = await call_env_method(env, "step", env_action)
284
- obs = getattr(result, "observation", None)
285
- done = bool(getattr(result, "done", False))
286
- reward = float(getattr(result, "reward", 0.0) or 0.0)
287
-
288
- rewards.append(reward)
289
- steps_taken = step
290
-
291
- error = extract_env_error(result)
292
-
293
- if obs is not None:
294
- score = float(getattr(obs, "score", score) or score)
295
- if done:
296
- success = bool(getattr(obs, "tests_passed", False)) and not bool(
297
- getattr(obs, "syntax_error", False)
298
- )
299
-
300
- score = max(0.0001, min(0.9999, score))
301
- log_step(step=step, action=action_str, reward=reward, done=done, error=error)
302
-
303
- history.append(f"Step {step}: {action_str} -> reward {reward:.2f}")
304
-
305
- score = max(0.0001, min(0.9999, float(score)))
306
-
307
- except Exception as exc:
308
- print(f"[SYSTEM ERROR] {exc}", file=sys.stderr, flush=True)
309
- success = False
310
- finally:
311
- log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
312
-
313
- except Exception as exc:
314
- print(f"[SYSTEM ERROR] {exc}", file=sys.stderr, flush=True)
315
- finally:
316
- try:
317
- if env is not None and hasattr(env, "close"):
318
- await call_env_method(env, "close")
319
- except Exception:
320
- pass
321
-
322
-
323
- if __name__ == "__main__":
324
- try:
325
- asyncio.run(main())
326
- sys.exit(0)
327
- except Exception as exc:
328
- print(f"[CRITICAL VALIDATION ERROR] {exc}", file=sys.stderr, flush=True)
329
- sys.exit(0)
330
- except BaseException as base_exc:
331
- print(f"[BASE EXCEPTION] {base_exc}", file=sys.stderr, flush=True)
332
- sys.exit(0)
 
1
+ import asyncio
2
+ import inspect
3
+ import json
4
+ import os
5
+ import sys
6
+ import textwrap
7
+ from typing import Any, List, Optional
8
+
9
+ from dotenv import load_dotenv
10
+ from openai import OpenAI
11
+
12
+ from client import DebugzeroEnv
13
+ from models import DebugzeroAction
14
+
15
+ load_dotenv()
16
+
17
+
18
+ API_BASE_URL = os.getenv("API_BASE_URL") or os.getenv("OPENAI_BASE_URL", "https://openrouter.ai/api/v1")
19
+ MODEL_NAME = os.getenv("MODEL_NAME") or os.getenv("OPENAI_MODEL", "meta-llama/llama-3.1-8b-instruct")
20
+ API_KEY = os.getenv("API_KEY") or os.getenv("OPENAI_API_KEY") or os.getenv("HF_TOKEN")
21
+
22
+ LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
23
+ ENV_URL = os.getenv("DEBUGZERO_ENV_URL", "http://localhost:8000")
24
+ TASK_NAME = os.getenv("DEBUGZERO_TASK", "debugging-self-play")
25
+ BENCHMARK = os.getenv("DEBUGZERO_BENCHMARK", "debugzero")
26
+ BUG_FOCUS = os.getenv("DEBUGZERO_BUG_FOCUS")
27
+
28
+ NUM_EPISODES = int(os.getenv("NUM_EPISODES", "3"))
29
+ MAX_STEPS = int(os.getenv("MAX_STEPS", "8"))
30
+ PROPOSER_TEMPERATURE = float(os.getenv("PROPOSER_TEMPERATURE", "0.7"))
31
+ SOLVER_TEMPERATURE = float(os.getenv("SOLVER_TEMPERATURE", "0.2"))
32
+ MAX_TOKENS = int(os.getenv("MAX_TOKENS", "1024"))
33
+
34
+
35
+ def extract_python_code(text: str) -> str:
36
+ content = (text or "").strip()
37
+ if content.startswith("```"):
38
+ content = content.split("\n", 1)[-1]
39
+ if content.endswith("```"):
40
+ content = content.rsplit("\n", 1)[0]
41
+ return content.strip()
42
+
43
+
44
+ def compact_action_string(role: str, code: str) -> str:
45
+ obj = {"role": role, "code": code}
46
+ return json.dumps(obj, separators=(",", ":"), ensure_ascii=False)
47
+
48
+
49
+ def log_start(task: str, env: str, model: str) -> None:
50
+ print(f"[START] task={task} env={env} model={model}", flush=True)
51
+
52
+
53
+ def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
54
+ error_val = error if error is not None else "null"
55
+ action_str = action.replace("\n", "\\n")
56
+ print(
57
+ f"[STEP] step={step} action={action_str} reward={reward:.2f} done={str(done).lower()} error={error_val}",
58
+ flush=True,
59
+ )
60
+
61
+
62
+ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
63
+ rewards_str = ",".join(f"{r:.2f}" for r in rewards)
64
+ print(
65
+ f"[END] success={str(success).lower()} steps={steps} score={score:.4f} rewards={rewards_str}",
66
+ flush=True,
67
+ )
68
+
69
+
70
+ def summarize_error(text: str, max_chars: int = 220) -> str:
71
+ cleaned = " ".join(text.strip().split())
72
+ if not cleaned:
73
+ return "null"
74
+ if len(cleaned) <= max_chars:
75
+ return cleaned
76
+ return cleaned[: max_chars - 3].rstrip() + "..."
77
+
78
+
79
+ def extract_env_error(result: Any) -> Optional[str]:
80
+ for attr in ("last_action_error", "error", "message"):
81
+ if hasattr(result, attr):
82
+ value = getattr(result, attr)
83
+ if value:
84
+ return str(value)
85
+
86
+ obs = getattr(result, "observation", None)
87
+ if obs is None:
88
+ return None
89
+
90
+ for attr in ("last_action_error", "error"):
91
+ if hasattr(obs, attr):
92
+ value = getattr(obs, attr)
93
+ if value:
94
+ return str(value)
95
+
96
+ execution_result = getattr(obs, "execution_result", "")
97
+ if isinstance(execution_result, str) and execution_result:
98
+ if getattr(obs, "syntax_error", False):
99
+ return summarize_error(execution_result)
100
+ if execution_result.startswith("Unsafe import detected."):
101
+ return execution_result
102
+ if not getattr(obs, "tests_passed", False):
103
+ return summarize_error(execution_result)
104
+
105
+ return None
106
+
107
+
108
+ def build_prompt(obs_dict: dict[str, Any], history: List[str]) -> str:
109
+ role = str(obs_dict.get("role_next", "proposer"))
110
+ current_code = str(obs_dict.get("current_code", ""))
111
+ execution_result = str(obs_dict.get("execution_result", ""))
112
+ tests_passed = bool(obs_dict.get("tests_passed", False))
113
+ syntax_error = bool(obs_dict.get("syntax_error", False))
114
+ metadata = obs_dict.get("metadata", {}) or {}
115
+ seed_id = metadata.get("seed_id", "unknown")
116
+ history_block = "\n".join(history[-4:]) if history else "None"
117
+
118
+ if role == "proposer":
119
+ focus_line = ""
120
+ if BUG_FOCUS:
121
+ focus_line = f"- Focus specifically on the `{BUG_FOCUS}` mutation family.\n"
122
+ task_block = textwrap.dedent(
123
+ f"""
124
+ You are the Proposer in a debugging self-play environment.
125
+ Return a full Python function with exactly one small logical bug injected.
126
+
127
+ Rules:
128
+ - Keep the code valid Python.
129
+ - Keep the same function signature.
130
+ - Preserve the overall structure and formatting as much as possible.
131
+ - Make exactly one small local behavioral change.
132
+ - Avoid comments, explanations, markdown outside the code block, and broad rewrites.
133
+ {focus_line}- Your goal is to make tests fail without creating a syntax error.
134
+ """
135
+ ).strip()
136
+ else:
137
+ task_block = textwrap.dedent(
138
+ """
139
+ You are the Solver in a debugging self-play environment.
140
+ Return the full fixed Python function.
141
+
142
+ Rules:
143
+ - Keep the code valid Python.
144
+ - Keep the same function signature.
145
+ - Make the smallest correct local fix you can.
146
+ - Use the failure output to guide the repair.
147
+ - Avoid comments, explanations, markdown outside the code block, and unrelated refactors.
148
+ """
149
+ ).strip()
150
+
151
+ return textwrap.dedent(
152
+ f"""
153
+ {task_block}
154
+
155
+ Current environment state:
156
+ - seed_id: {seed_id}
157
+ - role_next: {role}
158
+ - tests_passed: {tests_passed}
159
+ - syntax_error: {syntax_error}
160
+
161
+ Current code:
162
+ ```python
163
+ {current_code}
164
+ ```
165
+
166
+ Execution result:
167
+ {execution_result if execution_result else "None"}
168
+
169
+ Previous actions:
170
+ {history_block}
171
+
172
+ Return only the full Python code inside triple backticks.
173
+ """
174
+ ).strip()
175
+
176
+
177
+ def get_model_code(client: OpenAI, obs_dict: dict[str, Any], history: List[str]) -> str:
178
+ role = str(obs_dict.get("role_next", "proposer"))
179
+ prompt = build_prompt(obs_dict, history)
180
+ temperature = PROPOSER_TEMPERATURE if role == "proposer" else SOLVER_TEMPERATURE
181
+
182
+ response = client.chat.completions.create(
183
+ model=MODEL_NAME,
184
+ messages=[
185
+ {"role": "system", "content": "You are an expert Python coder."},
186
+ {"role": "user", "content": prompt},
187
+ ],
188
+ temperature=temperature,
189
+ max_tokens=MAX_TOKENS,
190
+ )
191
+
192
+ return extract_python_code(response.choices[0].message.content or "")
193
+
194
+
195
+ async def maybe_await(value: Any) -> Any:
196
+ if inspect.isawaitable(value):
197
+ return await value
198
+ return value
199
+
200
+
201
+ async def call_env_method(obj: Any, method_name: str, *args: Any) -> Any:
202
+ method = getattr(obj, method_name)
203
+ result = method(*args)
204
+ return await maybe_await(result)
205
+
206
+
207
+ async def make_env() -> Any:
208
+ max_retries = 30
209
+
210
+ if LOCAL_IMAGE_NAME:
211
+ for attempt in range(max_retries):
212
+ try:
213
+ env = DebugzeroEnv.from_docker_image(LOCAL_IMAGE_NAME)
214
+ return await maybe_await(env)
215
+ except Exception as exc:
216
+ print(
217
+ f"[SYSTEM ERROR] Failed to start Docker environment (attempt {attempt + 1}/{max_retries}): {exc}",
218
+ file=sys.stderr,
219
+ flush=True,
220
+ )
221
+ if attempt < max_retries - 1:
222
+ await asyncio.sleep(5.0)
223
+ else:
224
+ raise
225
+
226
+ for attempt in range(max_retries):
227
+ try:
228
+ return DebugzeroEnv(base_url=ENV_URL)
229
+ except Exception as exc:
230
+ print(
231
+ f"[SYSTEM ERROR] Env connection to {ENV_URL} failed (attempt {attempt + 1}/{max_retries}): {exc}",
232
+ file=sys.stderr,
233
+ flush=True,
234
+ )
235
+ if attempt < max_retries - 1:
236
+ await asyncio.sleep(5.0)
237
+ else:
238
+ raise
239
+
240
+
241
+ async def main() -> None:
242
+ if not API_KEY:
243
+ print("[SYSTEM ERROR] Missing API key. Set API_KEY, OPENAI_API_KEY, or HF_TOKEN.", file=sys.stderr, flush=True)
244
+ return
245
+
246
+ client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
247
+ env = None
248
+
249
+ try:
250
+ env = await make_env()
251
+
252
+ for _episode in range(1, NUM_EPISODES + 1):
253
+ history: List[str] = []
254
+ rewards: List[float] = []
255
+ steps_taken = 0
256
+ score = 0.0
257
+ success = False
258
+
259
+ log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
260
+
261
+ try:
262
+ result = await call_env_method(env, "reset")
263
+ done = bool(getattr(result, "done", False))
264
+ obs = getattr(result, "observation", None)
265
+
266
+ for step in range(1, MAX_STEPS + 1):
267
+ if done or obs is None:
268
+ break
269
+
270
+ obs_dict = obs.model_dump() if hasattr(obs, "model_dump") else obs.dict()
271
+ role = str(obs_dict.get("role_next", "proposer"))
272
+
273
+ try:
274
+ code = await asyncio.to_thread(get_model_code, client, obs_dict, history)
275
+ env_action = DebugzeroAction(role=role, code=code)
276
+ action_str = compact_action_string(role, code)
277
+ except Exception as exc:
278
+ print(f"[SYSTEM ERROR] Model generation failed: {exc}", file=sys.stderr, flush=True)
279
+ code = obs_dict.get("current_code", "")
280
+ env_action = DebugzeroAction(role=role, code=code)
281
+ action_str = compact_action_string(role, code)
282
+
283
+ result = await call_env_method(env, "step", env_action)
284
+ obs = getattr(result, "observation", None)
285
+ done = bool(getattr(result, "done", False))
286
+ reward = float(getattr(result, "reward", 0.0) or 0.0)
287
+
288
+ rewards.append(reward)
289
+ steps_taken = step
290
+
291
+ error = extract_env_error(result)
292
+
293
+ if obs is not None:
294
+ score = float(getattr(obs, "score", score) or score)
295
+ if done:
296
+ success = bool(getattr(obs, "tests_passed", False)) and not bool(
297
+ getattr(obs, "syntax_error", False)
298
+ )
299
+
300
+ score = max(0.0001, min(0.9999, score))
301
+ log_step(step=step, action=action_str, reward=reward, done=done, error=error)
302
+
303
+ history.append(f"Step {step}: {action_str} -> reward {reward:.2f}")
304
+
305
+ score = max(0.0001, min(0.9999, float(score)))
306
+
307
+ except Exception as exc:
308
+ print(f"[SYSTEM ERROR] {exc}", file=sys.stderr, flush=True)
309
+ success = False
310
+ finally:
311
+ log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
312
+
313
+ except Exception as exc:
314
+ print(f"[SYSTEM ERROR] {exc}", file=sys.stderr, flush=True)
315
+ finally:
316
+ try:
317
+ if env is not None and hasattr(env, "close"):
318
+ await call_env_method(env, "close")
319
+ except Exception:
320
+ pass
321
+
322
+
323
+ if __name__ == "__main__":
324
+ try:
325
+ asyncio.run(main())
326
+ sys.exit(0)
327
+ except Exception as exc:
328
+ print(f"[CRITICAL VALIDATION ERROR] {exc}", file=sys.stderr, flush=True)
329
+ sys.exit(0)
330
+ except BaseException as base_exc:
331
+ print(f"[BASE EXCEPTION] {base_exc}", file=sys.stderr, flush=True)
332
+ sys.exit(0)
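Note on `extract_python_code` in the file above: as written, it strips a fence only when the model reply starts with one, so a reply that puts a sentence before the code block is passed to the environment with the prose and fence markers still attached. A more tolerant variant could search for the first fenced block anywhere in the reply. The following is a minimal sketch under that assumption. `extract_fenced_code` and `FENCE` are hypothetical names, and this is not the behaviour of the committed function.

```python
import re

# The fence marker is built indirectly so this snippet can sit inside a fenced block.
FENCE = "`" * 3

# Opening fence, optional info string up to the newline, then the shortest body
# that ends at the next fence marker.
_FENCE_RE = re.compile(
    re.escape(FENCE) + r"[^\n`]*\n(.*?)" + re.escape(FENCE),
    re.DOTALL,
)


def extract_fenced_code(text: str) -> str:
    """Return the body of the first fenced block, or the stripped text if none."""
    content = (text or "").strip()
    match = _FENCE_RE.search(content)
    if match:
        return match.group(1).strip()
    # No fence found: fall back to the raw reply, as the original helper does.
    return content
```

The fallback keeps unfenced replies working, so swapping this in would only change results for replies that wrap the code in extra text.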
models.py CHANGED
@@ -22,15 +22,15 @@ class DebugzeroAction(Action):
     code: str = Field(..., description="Code injected (by proposer) or fixed (by solver)")
 
 
-class DebugzeroObservation(Observation):
-    """Observation from the DebugZero environment following sandbox execution."""
-
-    role_next: str = Field(default="proposer", description="The role supposed to play next")
-    current_code: str = Field(default="", description="The current state of the python code")
-    execution_result: str = Field(default="", description="Result of evaluating tests in the sandbox")
-    tests_passed: bool = Field(default=False, description="Whether the tests passed")
-    syntax_error: bool = Field(default=False, description="Whether the code had a parse/syntax error")
-    score: float = Field(default=0.0, description="Episode progress score in the range [0.0, 1.0]")
+class DebugzeroObservation(Observation):
+    """Observation from the DebugZero environment following sandbox execution."""
+
+    role_next: str = Field(default="proposer", description="The role supposed to play next")
+    current_code: str = Field(default="", description="The current state of the python code")
+    execution_result: str = Field(default="", description="Result of evaluating tests in the sandbox")
+    tests_passed: bool = Field(default=False, description="Whether the tests passed")
+    syntax_error: bool = Field(default=False, description="Whether the code had a parse/syntax error")
+    score: float = Field(default=0.0, description="Episode progress score in the range [0.0, 1.0]")
 
 class DebugzeroState(State):
     """State for the DebugZero environment, extending default state with seed context."""
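Note on how the observation fields are consumed: the driver scripts in this commit treat an episode as successful when the final observation has `tests_passed` set and `syntax_error` unset, and `inference.py` clamps the logged `score` to the range 0.0001 to 0.9999. The sketch below restates those two rules in isolation. `ObservationStub`, `episode_succeeded`, and `clamp_score` are hypothetical names, and the dataclass is only a stand-in for the real Pydantic model.

```python
from dataclasses import dataclass


@dataclass
class ObservationStub:
    """Hypothetical stand-in mirroring three DebugzeroObservation fields."""

    tests_passed: bool = False
    syntax_error: bool = False
    score: float = 0.0


def episode_succeeded(obs: ObservationStub) -> bool:
    # Same condition the driver loops apply once done is True.
    return bool(obs.tests_passed) and not bool(obs.syntax_error)


def clamp_score(score: float, low: float = 0.0001, high: float = 0.9999) -> float:
    # Keep the reported score strictly inside (0, 1), as inference.py does.
    return max(low, min(high, float(score)))
```

Keeping both rules in small helpers like these would be one way to avoid the two scripts drifting apart, although the commit itself inlines them.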
server/debugZero_environment.py CHANGED
@@ -1,224 +1,224 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the BSD-style license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- """
8
- DebugZero Environment Implementation for adversarial bug-fixing self-play.
9
- """
10
-
11
- from __future__ import annotations
12
-
13
- from uuid import uuid4
14
-
15
- from openenv.core.env_server.interfaces import Environment
16
-
17
- try:
18
- from ..models import DebugzeroAction, DebugzeroObservation, DebugzeroState
19
- from .tasks import SEED_BANK, SeedSpec
20
- except ImportError:
21
- from models import DebugzeroAction, DebugzeroObservation, DebugzeroState
22
- from server.tasks import SEED_BANK, SeedSpec
23
-
24
- try:
25
- from .bug_injector import infer_bug_operator
26
- from .graders import (
27
- compute_ast_distance,
28
- compute_proposer_reward,
29
- compute_solver_reward,
30
- is_effectively_unchanged,
31
- )
32
- from .executor import execute_code
33
- except ImportError:
34
- from bug_injector import infer_bug_operator
35
- from graders import (
36
- compute_ast_distance,
37
- compute_proposer_reward,
38
- compute_solver_reward,
39
- is_effectively_unchanged,
40
- )
41
- from executor import execute_code
42
-
43
-
44
- class DebugzeroEnvironment(Environment):
45
- """
46
- Dual-role DebugZero Environment wrapping a Python sandbox execution
47
- for Proposer bug injection and Solver bug fixing.
48
- """
49
-
50
- SUPPORTS_CONCURRENT_SESSIONS: bool = True
51
-
52
- def __init__(self):
53
- self._reset_count = 0
54
- self._current_seed = SEED_BANK[0]
55
- self._current_bug_operator: str | None = None
56
- self._current_score = 0.0
57
- self._proposer_created_bug = False
58
- self._state = self._build_state(self._current_seed)
59
-
60
- def reset(self) -> DebugzeroObservation:
61
- seed = SEED_BANK[self._reset_count % len(SEED_BANK)]
62
- self._reset_count += 1
63
- self._current_seed = seed
64
- self._current_bug_operator = None
65
- self._current_score = 0.0
66
- self._proposer_created_bug = False
67
- self._state = self._build_state(seed)
68
-
69
- return self._build_observation(
70
- role_next="proposer",
71
- execution_result="",
72
- tests_passed=True,
73
- syntax_error=False,
74
- done=False,
75
- reward=0.0,
76
- score=0.0,
77
- )
78
-
79
- def step(self, action: DebugzeroAction) -> DebugzeroObservation: # type: ignore[override]
80
- self._state.step_count += 1
81
-
82
- tests = self._current_seed.test
83
-
84
- if action.role == "proposer":
85
- self._state.current_code = action.code
86
- result = execute_code(self._state.current_code, tests)
87
- self._state.role_turn = "solver"
88
- reward, score = self._proposer_step_feedback(action.code, result)
89
-
90
- return self._build_observation(
91
- role_next="solver",
92
- execution_result=self._truncate_execution_output(result.output),
93
- tests_passed=result.passed,
94
- syntax_error=result.syntax_error,
95
- done=False,
96
- reward=reward,
97
- score=score,
98
- )
99
-
100
- if action.role == "solver":
101
- self._state.current_code = action.code
102
- result = execute_code(self._state.current_code, tests)
103
- self._state.role_turn = "end"
104
- reward, score = self._solver_step_feedback(result)
105
-
106
- return self._build_observation(
107
- role_next="proposer",
108
- execution_result=self._truncate_execution_output(result.output),
109
- tests_passed=result.passed,
110
- syntax_error=result.syntax_error,
111
- done=True,
112
- reward=reward,
113
- score=score,
114
- )
115
-
116
- self._current_score = 0.0
117
- self._proposer_created_bug = False
118
- return self._build_observation(
119
- role_next="end",
120
- execution_result="",
121
- tests_passed=False,
122
- syntax_error=False,
123
- done=True,
124
- reward=0.0,
125
- score=0.0,
126
- )
127
-
128
- @property
129
- def state(self) -> DebugzeroState:
130
- return self._state
131
-
132
- def _build_state(self, seed: SeedSpec) -> DebugzeroState:
133
- return DebugzeroState(
134
- episode_id=str(uuid4()),
135
- step_count=0,
136
- seed_id=seed.seed_id,
137
- original_code=seed.original_code,
138
- current_code=seed.original_code,
139
- role_turn="proposer",
140
- )
141
-
142
- def _build_observation(
143
- self,
144
- *,
145
- role_next: str,
146
- execution_result: str,
147
- tests_passed: bool,
148
- syntax_error: bool,
149
- done: bool,
150
- reward: float,
151
- score: float,
152
- ) -> DebugzeroObservation:
153
- self._current_score = score
154
- return DebugzeroObservation(
155
- role_next=role_next,
156
- current_code=self._state.current_code,
157
- execution_result=execution_result,
158
- tests_passed=tests_passed,
159
- syntax_error=syntax_error,
160
- score=score,
161
- done=done,
162
- reward=reward,
163
- metadata=self._observation_metadata(),
164
- )
165
-
166
- def _proposer_step_feedback(self, candidate_code: str, result: object) -> tuple[float, float]:
167
- original_code = self._state.original_code
168
- execution_output = getattr(result, "output", "") or ""
169
- syntax_error = bool(getattr(result, "syntax_error", False))
170
- tests_passed = bool(getattr(result, "passed", False))
171
- unsafe_code = execution_output.startswith("Unsafe import detected.")
172
-
173
- unchanged_code = is_effectively_unchanged(original_code, candidate_code)
174
- changed_but_passing = (not unchanged_code) and tests_passed and (not syntax_error)
175
- plausibility_score = 0.0 if syntax_error else compute_ast_distance(original_code, candidate_code)
176
-
177
- reward = compute_proposer_reward(
178
- {
179
- "seed_id": self._state.seed_id,
180
- "tests_passed": tests_passed,
181
- "syntax_error": syntax_error,
182
- "unsafe_code": unsafe_code,
183
- "unchanged_code": unchanged_code,
184
- "changed_but_passing": changed_but_passing,
185
- "plausibility_score": plausibility_score,
186
- }
187
- )
188
-
189
- valid_bug = (not tests_passed) and (not syntax_error) and (not unsafe_code)
190
- self._proposer_created_bug = valid_bug
191
- self._current_bug_operator = infer_bug_operator(original_code, candidate_code) if valid_bug else None
192
- score = 0.5 if valid_bug else 0.0
193
- return reward, score
194
-
195
- def _solver_step_feedback(self, result: object) -> tuple[float, float]:
196
- execution_output = getattr(result, "output", "") or ""
197
- syntax_error = bool(getattr(result, "syntax_error", False))
198
- tests_passed = bool(getattr(result, "passed", False))
199
- unsafe_code = execution_output.startswith("Unsafe import detected.")
200
-
201
- reward = compute_solver_reward(
202
- {
203
- "seed_id": self._state.seed_id,
204
- "tests_passed": tests_passed,
205
- "syntax_error": syntax_error,
206
- "unsafe_code": unsafe_code,
207
- }
208
- )
209
-
210
- solved = tests_passed and (not syntax_error) and (not unsafe_code)
211
- score = 1.0 if solved else (0.5 if self._proposer_created_bug else 0.0)
212
- return reward, score
213
-
214
- def _truncate_execution_output(self, output: str) -> str:
215
- return output[:500] if output else ""
216
-
217
- def _observation_metadata(self) -> dict[str, str]:
218
- metadata = {
219
- "seed_id": self._state.seed_id,
220
- "original_code": self._state.original_code,
221
- }
222
- if self._current_bug_operator:
223
- metadata["bug_operator"] = self._current_bug_operator
224
- return metadata
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ DebugZero Environment Implementation for adversarial bug-fixing self-play.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ from uuid import uuid4
14
+
15
+ from openenv.core.env_server.interfaces import Environment
16
+
17
+ try:
18
+ from ..models import DebugzeroAction, DebugzeroObservation, DebugzeroState
19
+ from .tasks import SEED_BANK, SeedSpec
20
+ except ImportError:
21
+ from models import DebugzeroAction, DebugzeroObservation, DebugzeroState
22
+ from server.tasks import SEED_BANK, SeedSpec
23
+
24
+ try:
25
+ from .bug_injector import infer_bug_operator
26
+ from .graders import (
27
+ compute_ast_distance,
28
+ compute_proposer_reward,
29
+ compute_solver_reward,
30
+ is_effectively_unchanged,
31
+ )
32
+ from .executor import execute_code
33
+ except ImportError:
34
+ from bug_injector import infer_bug_operator
35
+ from graders import (
36
+ compute_ast_distance,
37
+ compute_proposer_reward,
38
+ compute_solver_reward,
39
+ is_effectively_unchanged,
40
+ )
41
+ from executor import execute_code
42
+
43
+
44
+ class DebugzeroEnvironment(Environment):
45
+ """
46
+ Dual-role DebugZero Environment wrapping a Python sandbox execution
47
+ for Proposer bug injection and Solver bug fixing.
48
+ """
49
+
50
+ SUPPORTS_CONCURRENT_SESSIONS: bool = True
51
+
52
+ def __init__(self):
53
+ self._reset_count = 0
54
+ self._current_seed = SEED_BANK[0]
55
+ self._current_bug_operator: str | None = None
56
+ self._current_score = 0.0
57
+ self._proposer_created_bug = False
58
+ self._state = self._build_state(self._current_seed)
59
+
60
+ def reset(self) -> DebugzeroObservation:
61
+ seed = SEED_BANK[self._reset_count % len(SEED_BANK)]
62
+ self._reset_count += 1
63
+ self._current_seed = seed
64
+ self._current_bug_operator = None
65
+ self._current_score = 0.0
66
+ self._proposer_created_bug = False
67
+ self._state = self._build_state(seed)
68
+
69
+ return self._build_observation(
70
+ role_next="proposer",
71
+ execution_result="",
72
+ tests_passed=True,
73
+ syntax_error=False,
74
+ done=False,
75
+ reward=0.0,
76
+ score=0.0,
77
+ )
78
+
79
+ def step(self, action: DebugzeroAction) -> DebugzeroObservation: # type: ignore[override]
80
+ self._state.step_count += 1
81
+
82
+ tests = self._current_seed.test
83
+
84
+ if action.role == "proposer":
85
+ self._state.current_code = action.code
86
+ result = execute_code(self._state.current_code, tests)
87
+ self._state.role_turn = "solver"
88
+ reward, score = self._proposer_step_feedback(action.code, result)
89
+
90
+ return self._build_observation(
91
+ role_next="solver",
92
+ execution_result=self._truncate_execution_output(result.output),
93
+ tests_passed=result.passed,
94
+ syntax_error=result.syntax_error,
95
+ done=False,
96
+ reward=reward,
97
+ score=score,
98
+ )
99
+
100
+ if action.role == "solver":
101
+ self._state.current_code = action.code
102
+ result = execute_code(self._state.current_code, tests)
103
+ self._state.role_turn = "end"
104
+ reward, score = self._solver_step_feedback(result)
105
+
106
+ return self._build_observation(
107
+ role_next="proposer",
108
+ execution_result=self._truncate_execution_output(result.output),
109
+ tests_passed=result.passed,
110
+ syntax_error=result.syntax_error,
111
+ done=True,
112
+ reward=reward,
113
+ score=score,
114
+ )
115
+
116
+ self._current_score = 0.0
117
+ self._proposer_created_bug = False
118
+ return self._build_observation(
119
+ role_next="end",
120
+ execution_result="",
121
+ tests_passed=False,
122
+ syntax_error=False,
123
+ done=True,
124
+ reward=0.0,
125
+ score=0.0,
126
+ )
127
+
128
+ @property
129
+ def state(self) -> DebugzeroState:
130
+ return self._state
131
+
132
+ def _build_state(self, seed: SeedSpec) -> DebugzeroState:
133
+ return DebugzeroState(
134
+ episode_id=str(uuid4()),
135
+ step_count=0,
136
+ seed_id=seed.seed_id,
137
+ original_code=seed.original_code,
138
+ current_code=seed.original_code,
139
+ role_turn="proposer",
140
+ )
141
+
142
+ def _build_observation(
143
+ self,
144
+ *,
145
+ role_next: str,
146
+ execution_result: str,
147
+ tests_passed: bool,
148
+ syntax_error: bool,
149
+ done: bool,
150
+ reward: float,
151
+ score: float,
152
+ ) -> DebugzeroObservation:
153
+ self._current_score = score
154
+ return DebugzeroObservation(
155
+ role_next=role_next,
156
+ current_code=self._state.current_code,
157
+ execution_result=execution_result,
158
+ tests_passed=tests_passed,
159
+ syntax_error=syntax_error,
160
+ score=score,
161
+ done=done,
162
+ reward=reward,
163
+ metadata=self._observation_metadata(),
164
+ )
165
+
166
+ def _proposer_step_feedback(self, candidate_code: str, result: object) -> tuple[float, float]:
167
+ original_code = self._state.original_code
168
+ execution_output = getattr(result, "output", "") or ""
169
+ syntax_error = bool(getattr(result, "syntax_error", False))
170
+ tests_passed = bool(getattr(result, "passed", False))
171
+ unsafe_code = execution_output.startswith("Unsafe import detected.")
172
+
173
+ unchanged_code = is_effectively_unchanged(original_code, candidate_code)
174
+ changed_but_passing = (not unchanged_code) and tests_passed and (not syntax_error)
175
+ plausibility_score = 0.0 if syntax_error else compute_ast_distance(original_code, candidate_code)
176
+
177
+ reward = compute_proposer_reward(
178
+ {
179
+ "seed_id": self._state.seed_id,
180
+ "tests_passed": tests_passed,
181
+ "syntax_error": syntax_error,
182
+ "unsafe_code": unsafe_code,
183
+ "unchanged_code": unchanged_code,
184
+ "changed_but_passing": changed_but_passing,
185
+ "plausibility_score": plausibility_score,
186
+ }
187
+ )
188
+
189
+ valid_bug = (not tests_passed) and (not syntax_error) and (not unsafe_code)
190
+ self._proposer_created_bug = valid_bug
191
+ self._current_bug_operator = infer_bug_operator(original_code, candidate_code) if valid_bug else None
192
+ score = 0.5 if valid_bug else 0.0
193
+ return reward, score
194
+
195
+ def _solver_step_feedback(self, result: object) -> tuple[float, float]:
196
+ execution_output = getattr(result, "output", "") or ""
197
+ syntax_error = bool(getattr(result, "syntax_error", False))
198
+ tests_passed = bool(getattr(result, "passed", False))
199
+ unsafe_code = execution_output.startswith("Unsafe import detected.")
200
+
201
+ reward = compute_solver_reward(
202
+ {
203
+ "seed_id": self._state.seed_id,
204
+ "tests_passed": tests_passed,
205
+ "syntax_error": syntax_error,
206
+ "unsafe_code": unsafe_code,
207
+ }
208
+ )
209
+
210
+ solved = tests_passed and (not syntax_error) and (not unsafe_code)
211
+ score = 1.0 if solved else (0.5 if self._proposer_created_bug else 0.0)
212
+ return reward, score
213
+
214
+ def _truncate_execution_output(self, output: str) -> str:
215
+ return output[:500] if output else ""
216
+
217
+ def _observation_metadata(self) -> dict[str, str]:
218
+ metadata = {
219
+ "seed_id": self._state.seed_id,
220
+ "original_code": self._state.original_code,
221
+ }
222
+ if self._current_bug_operator:
223
+ metadata["bug_operator"] = self._current_bug_operator
224
+ return metadata
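The episode scoring in `_proposer_step_feedback` and `_solver_step_feedback` above can be summarized in a small standalone sketch. This is not the repository's code: `ExecResult` is a hypothetical stand-in for whatever `execute_code` returns, assumed only to expose the `passed`, `syntax_error`, and `output` attributes that the environment reads. The helper names `is_unsafe`, `proposer_score`, and `solver_score` are also illustrative. Only the `score` value is modeled here; the separate `reward` computed in `graders.py` is left out.

```python
from __future__ import annotations

from dataclasses import dataclass

# Prefix the environment checks for in the executor output.
UNSAFE_PREFIX = "Unsafe import detected."


@dataclass
class ExecResult:
    """Hypothetical stand-in for the executor's result object."""

    passed: bool
    syntax_error: bool = False
    output: str = ""


def is_unsafe(result: ExecResult) -> bool:
    # Mirrors the startswith() check used in both feedback methods.
    return (result.output or "").startswith(UNSAFE_PREFIX)


def proposer_score(result: ExecResult) -> tuple[float, bool]:
    # A valid bug fails the tests without a syntax error or unsafe import.
    valid_bug = (not result.passed) and (not result.syntax_error) and (not is_unsafe(result))
    return (0.5 if valid_bug else 0.0), valid_bug


def solver_score(result: ExecResult, proposer_created_bug: bool) -> float:
    # A solved episode scores 1.0; otherwise the proposer's 0.5 carries over.
    solved = result.passed and (not result.syntax_error) and (not is_unsafe(result))
    if solved:
        return 1.0
    return 0.5 if proposer_created_bug else 0.0


# One episode: the proposer breaks the tests, then the solver repairs them.
score_after_proposer, created_bug = proposer_score(ExecResult(passed=False, output="AssertionError"))
score_after_solver = solver_score(ExecResult(passed=True), created_bug)
```

Read this way, the score is a coarse progress marker: 0.5 once a real bug exists, 1.0 once the tests pass again, and 0.0 when the proposer's submission was invalid and the solver also fails.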
server/graders.py CHANGED
@@ -60,25 +60,25 @@ def compute_ast_distance(original_code: str, mutated_code: str) -> float:
60
  return 0.0
61
 
62
 
63
- def compute_proposer_reward(meta: dict) -> float:
64
- if meta.get("syntax_error", False):
65
- return -0.5
66
-
67
- if meta.get("unsafe_code", False):
68
- return -0.5
69
-
70
- if meta.get("unchanged_code", False):
71
- return 0.0
72
-
73
- if meta.get("tests_passed", True):
74
- return 0.0
75
-
76
- if meta.get("changed_but_passing", False):
77
- return -0.1
78
-
79
- plausibility_bonus = meta.get("plausibility_score", 0.0)
80
- learnability_bonus = 0.0
81
- solve_rate = get_solve_rate(meta["seed_id"])
82
  if 0.2 <= solve_rate <= 0.8:
83
  learnability_bonus = 1.0
84
 
 
60
  return 0.0
61
 
62
 
63
+ def compute_proposer_reward(meta: dict) -> float:
64
+ if meta.get("syntax_error", False):
65
+ return -0.5
66
+
67
+ if meta.get("unsafe_code", False):
68
+ return -0.5
69
+
70
+ if meta.get("unchanged_code", False):
71
+ return 0.0
72
+
73
+ if meta.get("tests_passed", True):
74
+ return 0.0
75
+
76
+ if meta.get("changed_but_passing", False):
77
+ return -0.1
78
+
79
+ plausibility_bonus = meta.get("plausibility_score", 0.0)
80
+ learnability_bonus = 0.0
81
+ solve_rate = get_solve_rate(meta["seed_id"])
82
  if 0.2 <= solve_rate <= 0.8:
83
  learnability_bonus = 1.0
84
 
server/tasks.py CHANGED
@@ -111,11 +111,11 @@ SEED_BANK = (
111
  "check(count_nonempty)\n"
112
  ),
113
  ),
114
- SeedSpec(
115
- seed_id="DebugZero/5",
116
- entrypoint="running_max",
117
- prompt="def running_max(values: list[int]) -> int:",
118
- canonical_solution=(
119
  " best = values[0]\n"
120
  " for idx in range(1, len(values)):\n"
121
  " if values[idx] > best:\n"
@@ -127,93 +127,93 @@ SEED_BANK = (
127
  " assert candidate([3]) == 3\n"
128
  " assert candidate([3, 1, 5, 2]) == 5\n"
129
  " assert candidate([-1, -4, -2]) == -1\n"
130
- " assert candidate([0, 0, 0]) == 0\n\n"
131
- "check(running_max)\n"
132
- ),
133
- ),
134
- SeedSpec(
135
- seed_id="DebugZero/6",
136
- entrypoint="first_index_of",
137
- prompt="def first_index_of(values: list[int], target: int) -> int:",
138
- canonical_solution=(
139
- " for idx, value in enumerate(values):\n"
140
- " if value == target:\n"
141
- " return idx\n"
142
- " return -1\n"
143
- ),
144
- test=(
145
- "def check(candidate):\n"
146
- " assert candidate([], 3) == -1\n"
147
- " assert candidate([1, 2, 3], 1) == 0\n"
148
- " assert candidate([1, 2, 3], 3) == 2\n"
149
- " assert candidate([5, 7, 5, 7], 7) == 1\n"
150
- " assert candidate([9, 8], 4) == -1\n\n"
151
- "check(first_index_of)\n"
152
- ),
153
- ),
154
- SeedSpec(
155
- seed_id="DebugZero/7",
156
- entrypoint="drop_last",
157
- prompt="def drop_last(values: list[int]) -> list[int]:",
158
- canonical_solution=(
159
- " if not values:\n"
160
- " return []\n"
161
- " return values[:-1]\n"
162
- ),
163
- test=(
164
- "def check(candidate):\n"
165
- " assert candidate([]) == []\n"
166
- " assert candidate([1]) == []\n"
167
- " assert candidate([1, 2]) == [1]\n"
168
- " assert candidate([1, 2, 3, 4]) == [1, 2, 3]\n"
169
- " assert candidate([7, 7, 7]) == [7, 7]\n\n"
170
- "check(drop_last)\n"
171
- ),
172
- ),
173
- SeedSpec(
174
- seed_id="DebugZero/8",
175
- entrypoint="count_greater_than",
176
- prompt="def count_greater_than(values: list[int], threshold: int) -> int:",
177
- canonical_solution=(
178
- " total = 0\n"
179
- " for value in values:\n"
180
- " if value > threshold:\n"
181
- " total += 1\n"
182
- " return total\n"
183
- ),
184
- test=(
185
- "def check(candidate):\n"
186
- " assert candidate([], 1) == 0\n"
187
- " assert candidate([1, 2, 3], 2) == 1\n"
188
- " assert candidate([4, 5, 6], 3) == 3\n"
189
- " assert candidate([0, -1, 2, 2], 1) == 2\n"
190
- " assert candidate([5, 5, 5], 5) == 0\n\n"
191
- "check(count_greater_than)\n"
192
- ),
193
- ),
194
- SeedSpec(
195
- seed_id="DebugZero/9",
196
- entrypoint="prefix_sums",
197
- prompt="def prefix_sums(values: list[int]) -> list[int]:",
198
- canonical_solution=(
199
- " total = 0\n"
200
- " result = []\n"
201
- " for value in values:\n"
202
- " total += value\n"
203
- " result.append(total)\n"
204
- " return result\n"
205
- ),
206
- test=(
207
- "def check(candidate):\n"
208
- " assert candidate([]) == []\n"
209
- " assert candidate([3]) == [3]\n"
210
- " assert candidate([1, 2, 3]) == [1, 3, 6]\n"
211
- " assert candidate([2, -1, 4]) == [2, 1, 5]\n"
212
- " assert candidate([0, 0, 0]) == [0, 0, 0]\n\n"
213
- "check(prefix_sums)\n"
214
- ),
215
- ),
216
- )
217
 
218
  SEED_BY_ID = {seed.seed_id: seed for seed in SEED_BANK}
219
 
 
111
  "check(count_nonempty)\n"
112
  ),
113
  ),
114
+ SeedSpec(
115
+ seed_id="DebugZero/5",
116
+ entrypoint="running_max",
117
+ prompt="def running_max(values: list[int]) -> int:",
118
+ canonical_solution=(
119
  " best = values[0]\n"
120
  " for idx in range(1, len(values)):\n"
121
  " if values[idx] > best:\n"
 
127
  " assert candidate([3]) == 3\n"
128
  " assert candidate([3, 1, 5, 2]) == 5\n"
129
  " assert candidate([-1, -4, -2]) == -1\n"
130
+ " assert candidate([0, 0, 0]) == 0\n\n"
131
+ "check(running_max)\n"
132
+ ),
133
+ ),
134
+ SeedSpec(
135
+ seed_id="DebugZero/6",
136
+ entrypoint="first_index_of",
137
+ prompt="def first_index_of(values: list[int], target: int) -> int:",
138
+ canonical_solution=(
139
+ " for idx, value in enumerate(values):\n"
140
+ " if value == target:\n"
141
+ " return idx\n"
142
+ " return -1\n"
143
+ ),
144
+ test=(
145
+ "def check(candidate):\n"
146
+ " assert candidate([], 3) == -1\n"
147
+ " assert candidate([1, 2, 3], 1) == 0\n"
148
+ " assert candidate([1, 2, 3], 3) == 2\n"
149
+ " assert candidate([5, 7, 5, 7], 7) == 1\n"
150
+ " assert candidate([9, 8], 4) == -1\n\n"
151
+ "check(first_index_of)\n"
152
+ ),
153
+ ),
154
+ SeedSpec(
155
+ seed_id="DebugZero/7",
156
+ entrypoint="drop_last",
157
+ prompt="def drop_last(values: list[int]) -> list[int]:",
158
+ canonical_solution=(
159
+ " if not values:\n"
160
+ " return []\n"
161
+ " return values[:-1]\n"
162
+ ),
163
+ test=(
164
+ "def check(candidate):\n"
165
+ " assert candidate([]) == []\n"
166
+ " assert candidate([1]) == []\n"
167
+ " assert candidate([1, 2]) == [1]\n"
168
+ " assert candidate([1, 2, 3, 4]) == [1, 2, 3]\n"
169
+ " assert candidate([7, 7, 7]) == [7, 7]\n\n"
170
+ "check(drop_last)\n"
171
+ ),
172
+ ),
173
+ SeedSpec(
174
+ seed_id="DebugZero/8",
175
+ entrypoint="count_greater_than",
176
+ prompt="def count_greater_than(values: list[int], threshold: int) -> int:",
177
+ canonical_solution=(
178
+ " total = 0\n"
179
+ " for value in values:\n"
180
+ " if value > threshold:\n"
181
+ " total += 1\n"
182
+ " return total\n"
183
+ ),
184
+ test=(
185
+ "def check(candidate):\n"
186
+ " assert candidate([], 1) == 0\n"
187
+ " assert candidate([1, 2, 3], 2) == 1\n"
188
+ " assert candidate([4, 5, 6], 3) == 3\n"
189
+ " assert candidate([0, -1, 2, 2], 1) == 2\n"
190
+ " assert candidate([5, 5, 5], 5) == 0\n\n"
191
+ "check(count_greater_than)\n"
192
+ ),
193
+ ),
194
+ SeedSpec(
195
+ seed_id="DebugZero/9",
196
+ entrypoint="prefix_sums",
197
+ prompt="def prefix_sums(values: list[int]) -> list[int]:",
198
+ canonical_solution=(
199
+ " total = 0\n"
200
+ " result = []\n"
201
+ " for value in values:\n"
202
+ " total += value\n"
203
+ " result.append(total)\n"
204
+ " return result\n"
205
+ ),
206
+ test=(
207
+ "def check(candidate):\n"
208
+ " assert candidate([]) == []\n"
209
+ " assert candidate([3]) == [3]\n"
210
+ " assert candidate([1, 2, 3]) == [1, 3, 6]\n"
211
+ " assert candidate([2, -1, 4]) == [2, 1, 5]\n"
212
+ " assert candidate([0, 0, 0]) == [0, 0, 0]\n\n"
213
+ "check(prefix_sums)\n"
214
+ ),
215
+ ),
216
+ )
217
 
218
  SEED_BY_ID = {seed.seed_id: seed for seed in SEED_BANK}
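As a rough illustration of how one of the seeds above could be exercised, the sketch below joins a seed's `prompt`, `canonical_solution`, and `test` strings into a single program and runs it. The field values are copied from the `DebugZero/9` entry. Everything else is an assumption: `assemble_program` and `passes` are hypothetical helpers, the way the repository builds `original_code` is not shown in this diff, and plain `exec` is used purely for demonstration, with none of the sandboxing the real executor is described as providing.

```python
# Field values copied from the DebugZero/9 seed shown in the diff.
prompt = "def prefix_sums(values: list[int]) -> list[int]:"
canonical_solution = (
    "    total = 0\n"
    "    result = []\n"
    "    for value in values:\n"
    "        total += value\n"
    "        result.append(total)\n"
    "    return result\n"
)
test = (
    "def check(candidate):\n"
    "    assert candidate([]) == []\n"
    "    assert candidate([3]) == [3]\n"
    "    assert candidate([1, 2, 3]) == [1, 3, 6]\n"
    "    assert candidate([2, -1, 4]) == [2, 1, 5]\n"
    "    assert candidate([0, 0, 0]) == [0, 0, 0]\n\n"
    "check(prefix_sums)\n"
)


def assemble_program(prompt: str, body: str, test: str) -> str:
    # Hypothetical helper: function signature, then body, then the test harness.
    return prompt + "\n" + body + "\n" + test


def passes(program: str) -> bool:
    # Demonstration only: exec() is NOT a sandbox and should not run untrusted code.
    try:
        exec(program, {})
    except AssertionError:
        return False
    return True


clean_ok = passes(assemble_program(prompt, canonical_solution, test))

# A proposer-style mutation: overwrite the running total instead of accumulating it.
buggy_body = canonical_solution.replace("total += value", "total = value")
buggy_ok = passes(assemble_program(prompt, buggy_body, test))
```

Here the canonical body passes the harness while the one-token mutation fails on `[1, 2, 3]`, which is the kind of small, test-detectable bug the proposer role is meant to produce.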
219
 
validate-submission.sh CHANGED
@@ -1,185 +1,185 @@
1
- #!/usr/bin/env bash
2
- #
3
- # validate-submission.sh — OpenEnv Submission Validator
4
- #
5
- # Checks that your HF Space is live, Docker image builds, and openenv validate passes.
6
- #
7
- # Prerequisites:
8
- # - Docker: https://docs.docker.com/get-docker/
9
- # - openenv-core: pip install openenv-core
10
- # - curl (usually pre-installed)
11
- #
12
- # Run:
13
- # curl -fsSL https://raw.githubusercontent.com/<owner>/<repo>/main/scripts/validate-submission.sh | bash -s -- <ping_url> [repo_dir]
14
- #
15
- # Or download and run locally:
16
- # chmod +x validate-submission.sh
17
- # ./validate-submission.sh <ping_url> [repo_dir]
18
- #
19
- # Arguments:
20
- # ping_url Your HuggingFace Space URL (e.g. https://your-space.hf.space)
21
- # repo_dir Path to your repo (default: current directory)
22
- #
23
- # Examples:
24
- # ./validate-submission.sh https://my-team.hf.space
25
- # ./validate-submission.sh https://my-team.hf.space ./my-repo
26
- #
27
-
28
- set -uo pipefail
29
-
30
- DOCKER_BUILD_TIMEOUT=600
31
- if [ -t 1 ]; then
32
- RED='\033[0;31m'
33
- GREEN='\033[0;32m'
34
- YELLOW='\033[1;33m'
35
- BOLD='\033[1m'
36
- NC='\033[0m'
37
- else
38
- RED='' GREEN='' YELLOW='' BOLD='' NC=''
39
- fi
40
-
41
- run_with_timeout() {
42
- local secs="$1"; shift
43
- if command -v timeout &>/dev/null; then
44
- timeout "$secs" "$@"
45
- elif command -v gtimeout &>/dev/null; then
46
- gtimeout "$secs" "$@"
47
- else
48
- "$@" &
49
- local pid=$!
50
- ( sleep "$secs" && kill "$pid" 2>/dev/null ) &
51
- local watcher=$!
52
- wait "$pid" 2>/dev/null
53
- local rc=$?
54
- kill "$watcher" 2>/dev/null
55
- wait "$watcher" 2>/dev/null
56
- return $rc
57
- fi
58
- }
59
-
60
- portable_mktemp() {
61
- local prefix="${1:-validate}"
62
- mktemp "${TMPDIR:-/tmp}/${prefix}-XXXXXX" 2>/dev/null || mktemp
63
- }
64
-
65
- CLEANUP_FILES=()
66
- cleanup() { rm -f "${CLEANUP_FILES[@]+"${CLEANUP_FILES[@]}"}"; }
67
- trap cleanup EXIT
68
-
69
- PING_URL="${1:-}"
70
- REPO_DIR="${2:-.}"
71
-
72
- if [ -z "$PING_URL" ]; then
73
- printf "Usage: %s <ping_url> [repo_dir]\n" "$0"
74
- printf "\n"
75
- printf " ping_url Your HuggingFace Space URL (e.g. https://your-space.hf.space)\n"
76
- printf " repo_dir Path to your repo (default: current directory)\n"
77
- exit 1
78
- fi
79
-
80
- if ! REPO_DIR="$(cd "$REPO_DIR" 2>/dev/null && pwd)"; then
81
- printf "Error: directory '%s' not found\n" "${2:-.}"
82
- exit 1
83
- fi
84
- PING_URL="${PING_URL%/}"
85
- export PING_URL
86
- PASS=0
87
-
88
- log() { printf "[%s] %b\n" "$(date -u +%H:%M:%S)" "$*"; }
89
- pass() { log "${GREEN}PASSED${NC} -- $1"; PASS=$((PASS + 1)); }
90
- fail() { log "${RED}FAILED${NC} -- $1"; }
91
- hint() { printf " ${YELLOW}Hint:${NC} %b\n" "$1"; }
92
- stop_at() {
93
- printf "\n"
94
- printf "${RED}${BOLD}Validation stopped at %s.${NC} Fix the above before continuing.\n" "$1"
95
- exit 1
96
- }
97
-
98
- printf "\n"
99
- printf "${BOLD}========================================${NC}\n"
100
- printf "${BOLD} OpenEnv Submission Validator${NC}\n"
101
- printf "${BOLD}========================================${NC}\n"
102
- log "Repo: $REPO_DIR"
103
- log "Ping URL: $PING_URL"
104
- printf "\n"
105
-
106
- log "${BOLD}Step 1/3: Pinging HF Space${NC} ($PING_URL/reset) ..."
107
-
108
- CURL_OUTPUT=$(portable_mktemp "validate-curl")
109
- CLEANUP_FILES+=("$CURL_OUTPUT")
110
- HTTP_CODE=$(curl -s -o "$CURL_OUTPUT" -w "%{http_code}" -X POST \
111
- -H "Content-Type: application/json" -d '{}' \
112
- "$PING_URL/reset" --max-time 30 2>"$CURL_OUTPUT" || printf "000")
113
-
114
- if [ "$HTTP_CODE" = "200" ]; then
115
- pass "HF Space is live and responds to /reset"
116
- elif [ "$HTTP_CODE" = "000" ]; then
117
- fail "HF Space not reachable (connection failed or timed out)"
118
- hint "Check your network connection and that the Space is running."
119
- hint "Try: curl -s -o /dev/null -w '%%{http_code}' -X POST $PING_URL/reset"
120
- stop_at "Step 1"
121
- else
122
- fail "HF Space /reset returned HTTP $HTTP_CODE (expected 200)"
123
- hint "Make sure your Space is running and the URL is correct."
124
- hint "Try opening $PING_URL in your browser first."
125
- stop_at "Step 1"
126
- fi
127
-
128
- log "${BOLD}Step 2/3: Running docker build${NC} ..."
129
-
130
- if ! command -v docker &>/dev/null; then
131
- fail "docker command not found"
132
- hint "Install Docker: https://docs.docker.com/get-docker/"
133
- stop_at "Step 2"
134
- fi
135
-
136
- if [ -f "$REPO_DIR/Dockerfile" ]; then
137
- DOCKER_CONTEXT="$REPO_DIR"
138
- elif [ -f "$REPO_DIR/server/Dockerfile" ]; then
139
- DOCKER_CONTEXT="$REPO_DIR/server"
140
- else
141
- fail "No Dockerfile found in repo root or server/ directory"
142
- stop_at "Step 2"
143
- fi
144
-
145
- log " Found Dockerfile in $DOCKER_CONTEXT"
146
-
147
- BUILD_OK=false
148
- BUILD_OUTPUT=$(run_with_timeout "$DOCKER_BUILD_TIMEOUT" docker build "$DOCKER_CONTEXT" 2>&1) && BUILD_OK=true
149
-
150
- if [ "$BUILD_OK" = true ]; then
151
- pass "Docker build succeeded"
152
- else
153
- fail "Docker build failed (timeout=${DOCKER_BUILD_TIMEOUT}s)"
154
- printf "%s\n" "$BUILD_OUTPUT" | tail -20
155
- stop_at "Step 2"
156
- fi
157
-
158
- log "${BOLD}Step 3/3: Running openenv validate${NC} ..."
159
-
160
- if ! command -v openenv &>/dev/null; then
161
- fail "openenv command not found"
162
- hint "Install it: pip install openenv-core"
163
- stop_at "Step 3"
164
- fi
165
-
166
- VALIDATE_OK=false
167
- VALIDATE_OUTPUT=$(cd "$REPO_DIR" && openenv validate 2>&1) && VALIDATE_OK=true
168
-
169
- if [ "$VALIDATE_OK" = true ]; then
170
- pass "openenv validate passed"
171
- [ -n "$VALIDATE_OUTPUT" ] && log " $VALIDATE_OUTPUT"
172
- else
173
- fail "openenv validate failed"
174
- printf "%s\n" "$VALIDATE_OUTPUT"
175
- stop_at "Step 3"
176
- fi
177
-
178
- printf "\n"
179
- printf "${BOLD}========================================${NC}\n"
180
- printf "${GREEN}${BOLD} All 3/3 checks passed!${NC}\n"
181
- printf "${GREEN}${BOLD} Your submission is ready to submit.${NC}\n"
182
- printf "${BOLD}========================================${NC}\n"
183
- printf "\n"
184
-
185
  exit 0
 
1
+ #!/usr/bin/env bash
2
+ #
3
+ # validate-submission.sh — OpenEnv Submission Validator
4
+ #
5
+ # Checks that your HF Space is live, Docker image builds, and openenv validate passes.
6
+ #
7
+ # Prerequisites:
8
+ # - Docker: https://docs.docker.com/get-docker/
9
+ # - openenv-core: pip install openenv-core
10
+ # - curl (usually pre-installed)
11
+ #
12
+ # Run:
13
+ # curl -fsSL https://raw.githubusercontent.com/<owner>/<repo>/main/scripts/validate-submission.sh | bash -s -- <ping_url> [repo_dir]
14
+ #
15
+ # Or download and run locally:
16
+ # chmod +x validate-submission.sh
17
+ # ./validate-submission.sh <ping_url> [repo_dir]
18
+ #
19
+ # Arguments:
20
+ # ping_url Your HuggingFace Space URL (e.g. https://your-space.hf.space)
21
+ # repo_dir Path to your repo (default: current directory)
22
+ #
23
+ # Examples:
24
+ # ./validate-submission.sh https://my-team.hf.space
25
+ # ./validate-submission.sh https://my-team.hf.space ./my-repo
26
+ #
27
+
28
+ set -uo pipefail
29
+
30
+ DOCKER_BUILD_TIMEOUT=600
31
+ if [ -t 1 ]; then
32
+ RED='\033[0;31m'
33
+ GREEN='\033[0;32m'
34
+ YELLOW='\033[1;33m'
35
+ BOLD='\033[1m'
36
+ NC='\033[0m'
37
+ else
38
+ RED='' GREEN='' YELLOW='' BOLD='' NC=''
39
+ fi
40
+
41
+ run_with_timeout() {
42
+ local secs="$1"; shift
43
+ if command -v timeout &>/dev/null; then
44
+ timeout "$secs" "$@"
45
+ elif command -v gtimeout &>/dev/null; then
46
+ gtimeout "$secs" "$@"
47
+ else
48
+ "$@" &
49
+ local pid=$!
50
+ ( sleep "$secs" && kill "$pid" 2>/dev/null ) &
51
+ local watcher=$!
52
+ wait "$pid" 2>/dev/null
53
+ local rc=$?
54
+ kill "$watcher" 2>/dev/null
55
+ wait "$watcher" 2>/dev/null
56
+ return $rc
57
+ fi
58
+ }
59
+
60
+ portable_mktemp() {
61
+ local prefix="${1:-validate}"
62
+ mktemp "${TMPDIR:-/tmp}/${prefix}-XXXXXX" 2>/dev/null || mktemp
63
+ }
64
+
65
+ CLEANUP_FILES=()
66
+ cleanup() { rm -f "${CLEANUP_FILES[@]+"${CLEANUP_FILES[@]}"}"; }
67
+ trap cleanup EXIT
68
+
69
+ PING_URL="${1:-}"
70
+ REPO_DIR="${2:-.}"
71
+
72
+ if [ -z "$PING_URL" ]; then
73
+ printf "Usage: %s <ping_url> [repo_dir]\n" "$0"
74
+ printf "\n"
75
+ printf " ping_url Your HuggingFace Space URL (e.g. https://your-space.hf.space)\n"
76
+ printf " repo_dir Path to your repo (default: current directory)\n"
77
+ exit 1
78
+ fi
79
+
80
+ if ! REPO_DIR="$(cd "$REPO_DIR" 2>/dev/null && pwd)"; then
81
+ printf "Error: directory '%s' not found\n" "${2:-.}"
82
+ exit 1
83
+ fi
84
+ PING_URL="${PING_URL%/}"
85
+ export PING_URL
86
+ PASS=0
87
+
88
+ log() { printf "[%s] %b\n" "$(date -u +%H:%M:%S)" "$*"; }
89
+ pass() { log "${GREEN}PASSED${NC} -- $1"; PASS=$((PASS + 1)); }
90
+ fail() { log "${RED}FAILED${NC} -- $1"; }
91
+ hint() { printf " ${YELLOW}Hint:${NC} %b\n" "$1"; }
92
+ stop_at() {
93
+ printf "\n"
94
+ printf "${RED}${BOLD}Validation stopped at %s.${NC} Fix the above before continuing.\n" "$1"
95
+ exit 1
96
+ }
97
+
98
+ printf "\n"
99
+ printf "${BOLD}========================================${NC}\n"
100
+ printf "${BOLD} OpenEnv Submission Validator${NC}\n"
101
+ printf "${BOLD}========================================${NC}\n"
102
+ log "Repo: $REPO_DIR"
103
+ log "Ping URL: $PING_URL"
104
+ printf "\n"
105
+
106
+ log "${BOLD}Step 1/3: Pinging HF Space${NC} ($PING_URL/reset) ..."
107
+
108
+ CURL_OUTPUT=$(portable_mktemp "validate-curl")
109
+ CLEANUP_FILES+=("$CURL_OUTPUT")
110
+ HTTP_CODE=$(curl -s -o "$CURL_OUTPUT" -w "%{http_code}" -X POST \
111
+ -H "Content-Type: application/json" -d '{}' \
112
+ "$PING_URL/reset" --max-time 30 2>"$CURL_OUTPUT" || printf "000")
113
+
114
+ if [ "$HTTP_CODE" = "200" ]; then
115
+ pass "HF Space is live and responds to /reset"
116
+ elif [ "$HTTP_CODE" = "000" ]; then
117
+ fail "HF Space not reachable (connection failed or timed out)"
118
+ hint "Check your network connection and that the Space is running."
119
+ hint "Try: curl -s -o /dev/null -w '%{http_code}' -X POST $PING_URL/reset"
120
+ stop_at "Step 1"
121
+ else
122
+ fail "HF Space /reset returned HTTP $HTTP_CODE (expected 200)"
123
+ hint "Make sure your Space is running and the URL is correct."
124
+ hint "Try opening $PING_URL in your browser first."
125
+ stop_at "Step 1"
126
+ fi
127
+
128
+ log "${BOLD}Step 2/3: Running docker build${NC} ..."
129
+
130
+ if ! command -v docker &>/dev/null; then
131
+ fail "docker command not found"
132
+ hint "Install Docker: https://docs.docker.com/get-docker/"
133
+ stop_at "Step 2"
134
+ fi
135
+
136
+ if [ -f "$REPO_DIR/Dockerfile" ]; then
137
+ DOCKER_CONTEXT="$REPO_DIR"
138
+ elif [ -f "$REPO_DIR/server/Dockerfile" ]; then
139
+ DOCKER_CONTEXT="$REPO_DIR/server"
140
+ else
141
+ fail "No Dockerfile found in repo root or server/ directory"
142
+ stop_at "Step 2"
143
+ fi
144
+
145
+ log " Found Dockerfile in $DOCKER_CONTEXT"
146
+
147
+ BUILD_OK=false
148
+ BUILD_OUTPUT=$(run_with_timeout "$DOCKER_BUILD_TIMEOUT" docker build "$DOCKER_CONTEXT" 2>&1) && BUILD_OK=true
149
+
150
+ if [ "$BUILD_OK" = true ]; then
151
+ pass "Docker build succeeded"
152
+ else
153
+ fail "Docker build failed (timeout=${DOCKER_BUILD_TIMEOUT}s)"
154
+ printf "%s\n" "$BUILD_OUTPUT" | tail -20
155
+ stop_at "Step 2"
156
+ fi
157
+
158
+ log "${BOLD}Step 3/3: Running openenv validate${NC} ..."
159
+
160
+ if ! command -v openenv &>/dev/null; then
161
+ fail "openenv command not found"
162
+ hint "Install it: pip install openenv-core"
163
+ stop_at "Step 3"
164
+ fi
165
+
166
+ VALIDATE_OK=false
167
+ VALIDATE_OUTPUT=$(cd "$REPO_DIR" && openenv validate 2>&1) && VALIDATE_OK=true
168
+
169
+ if [ "$VALIDATE_OK" = true ]; then
170
+ pass "openenv validate passed"
171
+ [ -n "$VALIDATE_OUTPUT" ] && log " $VALIDATE_OUTPUT"
172
+ else
173
+ fail "openenv validate failed"
174
+ printf "%s\n" "$VALIDATE_OUTPUT"
175
+ stop_at "Step 3"
176
+ fi
177
+
178
+ printf "\n"
179
+ printf "${BOLD}========================================${NC}\n"
180
+ printf "${GREEN}${BOLD} All 3/3 checks passed!${NC}\n"
181
+ printf "${GREEN}${BOLD} Your submission is ready to submit.${NC}\n"
182
+ printf "${BOLD}========================================${NC}\n"
183
+ printf "\n"
184
+
185
  exit 0