databoysu committed on
Commit · 2e11c6a
1 Parent(s): 9882200
TraceFix-RL v1
- .gitignore +2 -1
- CLAUDE.md +2 -2
- README.md +41 -5
- __init__.py +3 -9
- client.py +4 -62
- context.py +3 -56
- environment.py +10 -120
- inference.py +8 -8
- models.py +1 -1
- openenv.yaml +1 -1
- pyproject.toml +6 -6
- sandbox.py +0 -41
- server/__init__.py +3 -9
- server/app.py +6 -13
- server/{swe_gym_environment.py → tracefix_rl_environment.py} +5 -21
.gitignore
CHANGED
|
@@ -2,4 +2,5 @@
|
|
| 2 |
.agents
|
| 3 |
.env
|
| 4 |
uv.lock
|
| 5 |
-
claude.md
|
|
|
|
|
|
| 2 |
.agents
|
| 3 |
.env
|
| 4 |
uv.lock
|
| 5 |
+
claude.md
|
| 6 |
+
__pycache__/
|
CLAUDE.md
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
# CLAUDE.md —
|
| 2 |
|
| 3 |
Codebase knowledge for AI assistants. Read before making changes.
|
| 4 |
|
|
@@ -259,7 +259,7 @@ Config from `os.getenv`:
|
|
| 259 |
|
| 260 |
**Exact stdout log format (regex-parsed by validation judge):**
|
| 261 |
```
|
| 262 |
-
[START] task=<task_name> env=
|
| 263 |
[STEP] step=<n> action=<action_type> reward=<r.rr> done=<true|false> error=<msg|null>
|
| 264 |
[END] success=<true|false> steps=<n> score=<s.sss> rewards=<r1,r2,...,rn>
|
| 265 |
```
|
|
|
|
| 1 |
+
# CLAUDE.md — TraceFix-RL
|
| 2 |
|
| 3 |
Codebase knowledge for AI assistants. Read before making changes.
|
| 4 |
|
|
|
|
| 259 |
|
| 260 |
**Exact stdout log format (regex-parsed by validation judge):**
|
| 261 |
```
|
| 262 |
+
[START] task=<task_name> env=TraceFixRL model=<model_name>
|
| 263 |
[STEP] step=<n> action=<action_type> reward=<r.rr> done=<true|false> error=<msg|null>
|
| 264 |
[END] success=<true|false> steps=<n> score=<s.sss> rewards=<r1,r2,...,rn>
|
| 265 |
```
|
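The CLAUDE.md hunk above says the stdout log lines are regex-parsed by the validation judge. The judge's actual patterns are not part of this commit, so the following is only a minimal sketch of how a `[STEP]` line in that format could be parsed. `STEP_RE` and `parse_step_line` are hypothetical names introduced for illustration.

```python
import re

# Hypothetical pattern for one [STEP] line; the field layout follows the
# format documented in CLAUDE.md, but the judge's real regex is not shown.
STEP_RE = re.compile(
    r"^\[STEP\] step=(?P<step>\d+) action=(?P<action>\S+) "
    r"reward=(?P<reward>-?\d+\.\d{2}) done=(?P<done>true|false) "
    r"error=(?P<error>.*)$"
)


def parse_step_line(line):
    """Return the fields of a [STEP] line as a dict, or None if it does not match."""
    match = STEP_RE.match(line.strip())
    if match is None:
        return None
    return {
        "step": int(match["step"]),
        "action": match["action"],
        "reward": float(match["reward"]),
        "done": match["done"] == "true",
        # The format uses the literal string "null" when there is no error.
        "error": None if match["error"] == "null" else match["error"],
    }
```

A line that drops a field or prints the reward with a different number of decimals returns `None` here, which is one reason the format is described as exact.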
README.md
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
emoji: 🧑‍💻
|
| 4 |
colorFrom: blue
|
| 5 |
colorTo: cyan
|
|
@@ -13,11 +13,13 @@ tags:
|
|
| 13 |
- software-engineering
|
| 14 |
---
|
| 15 |
|
| 16 |
-
#
|
| 17 |
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
|
|
|
|
|
|
| 21 |
|
| 22 |
## Core Design
|
| 23 |
|
|
@@ -29,6 +31,40 @@ submitting once all tests pass.
|
|
| 29 |
- Curriculum-ready task sampling:
|
| 30 |
easy/medium/hard buckets with safe random fallback for evaluator runs.
|
| 31 |
|
| 32 |
## Environment Files
|
| 33 |
|
| 34 |
- `models.py`: action/observation schemas
|
|
|
|
| 1 |
---
|
| 2 |
+
title: TraceFix-RL
|
| 3 |
emoji: 🧑‍💻
|
| 4 |
colorFrom: blue
|
| 5 |
colorTo: cyan
|
|
|
|
| 13 |
- software-engineering
|
| 14 |
---
|
| 15 |
|
| 16 |
+
# TraceFix-RL
|
| 17 |
|
| 18 |
+
TraceFix-RL is an OpenEnv-compatible environment designed to teach agent behavior
|
| 19 |
+
that looks like real software engineering work. Instead of one-shot answers,
|
| 20 |
+
the agent must inspect code, form a hypothesis, run tests, patch the code,
|
| 21 |
+
verify outcomes, and only then submit. The loop rewards disciplined debugging
|
| 22 |
+
and penalizes random edits, forcing the model to learn an engineering workflow.
|
| 23 |
|
| 24 |
## Core Design
|
| 25 |
|
|
|
|
| 31 |
- Curriculum-ready task sampling:
|
| 32 |
easy/medium/hard buckets with safe random fallback for evaluator runs.
|
| 33 |
|
| 34 |
+
## State Machine Training Pattern
|
| 35 |
+
|
| 36 |
+
The environment prompt in `environment.py` encodes a fixed operating pattern
|
| 37 |
+
the agent is expected to follow:
|
| 38 |
+
|
| 39 |
+
1. ORIENT: inspect code (`VIEW_CODE`)
|
| 40 |
+
2. DIAGNOSE: run tests and read failures (`RUN_TESTS`)
|
| 41 |
+
3. FIX: patch one region (`REPLACE_LINES`)
|
| 42 |
+
4. VERIFY: rerun tests (`RUN_TESTS`)
|
| 43 |
+
5. REPEAT: continue until all failures are resolved
|
| 44 |
+
6. SUBMIT: finalize only after tests pass
|
| 45 |
+
|
| 46 |
+
This structure is intentional: the environment trains planning, controlled
|
| 47 |
+
editing, and verification behavior, not just raw code generation.
|
| 48 |
+
|
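The ORIENT / DIAGNOSE / FIX / VERIFY / REPEAT / SUBMIT pattern added to the README can be read as a small decision rule over the previous action. The sketch below is an illustration of that reading, not code from the repository: `next_action` and `scripted_episode` are hypothetical helpers, and the diff does not show whether the environment itself enforces this order or only describes it in the prompt.

```python
def next_action(history, all_tests_pass):
    """Pick the next action type from the actions taken so far."""
    if not history:
        return "VIEW_CODE"          # ORIENT: look at the code first
    last = history[-1]
    if last == "VIEW_CODE":
        return "RUN_TESTS"          # DIAGNOSE: collect failures
    if last == "REPLACE_LINES":
        return "RUN_TESTS"          # VERIFY: rerun after every patch
    if last == "RUN_TESTS":
        # SUBMIT only once tests pass, otherwise FIX one region and repeat.
        return "SUBMIT" if all_tests_pass else "REPLACE_LINES"
    raise ValueError(f"unexpected action in history: {last!r}")


def scripted_episode(test_outcomes):
    """Replay the pattern; test_outcomes[i] is the result of the i-th RUN_TESTS."""
    history = []
    outcomes = iter(test_outcomes)
    passing = False
    while True:
        action = next_action(history, passing)
        history.append(action)
        if action == "RUN_TESTS":
            passing = next(outcomes)
        if action == "SUBMIT":
            return history
```

With one failing and then one passing test run, the replay yields the five-step sequence the README lists, with `SUBMIT` last.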
| 49 |
+
## Task Tiers And Test Structure
|
| 50 |
+
|
| 51 |
+
Tasks are organized in `tasks.py` into three tiers.
|
| 52 |
+
|
| 53 |
+
- Easy: 4 tasks, each with 4 unit tests.
|
| 54 |
+
Focus: basic operators, indexing, and simple string/array logic.
|
| 55 |
+
- Medium: 6 tasks, each with 4 unit tests.
|
| 56 |
+
Focus: recursive behavior, branching correctness, and text normalization edge cases.
|
| 57 |
+
- Hard: 6 tasks, each with 3-4 unit tests.
|
| 58 |
+
Focus: data-structure invariants, eviction/promotion logic, bracket mapping, and interval merging edge behavior.
|
| 59 |
+
|
| 60 |
+
Every task follows the same schema:
|
| 61 |
+
- `name`, `description`, `difficulty`, `bug_type`
|
| 62 |
+
- `code`: buggy implementation (line list)
|
| 63 |
+
- `solution`: reference implementation
|
| 64 |
+
- `tests`: callable assertions executed in the sandbox
|
| 65 |
+
|
| 66 |
+
This gives consistent training signals while scaling complexity across tiers.
|
| 67 |
+
|
| 68 |
## Environment Files
|
| 69 |
|
| 70 |
- `models.py`: action/observation schemas
|
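The task schema listed in the README hunk above (`name`, `description`, `difficulty`, `bug_type`, `code`, `solution`, `tests`) can be pictured with a toy entry. Everything in this sketch is assumed for illustration: the task itself is invented, `solution` is shown as a line list like `code`, the tests are assumed to receive the executed module namespace, and `run_tests_sketch` stands in for the real sandbox, which this diff does not show.

```python
def _test_adds_small_numbers(namespace):
    # Assumed test signature: each callable receives the executed namespace.
    assert namespace["add"](2, 3) == 5


# Invented example task; not one of the 16 tasks in tasks.py.
EXAMPLE_TASK = {
    "name": "add_two_numbers",
    "description": "add(a, b) should return the sum of its arguments.",
    "difficulty": "easy",
    "bug_type": "wrong_operator",
    "code": ["def add(a, b):", "    return a - b"],      # buggy line list
    "solution": ["def add(a, b):", "    return a + b"],  # reference version
    "tests": [_test_adds_small_numbers],
}


def run_tests_sketch(code_lines, tests):
    """Execute the code and report one pass/fail boolean per test (no sandboxing)."""
    namespace = {}
    exec("\n".join(code_lines), namespace)
    results = []
    for test in tests:
        try:
            test(namespace)
            results.append(True)
        except AssertionError:
            results.append(False)
    return results
```

Running the same tests against `code` and `solution` gives the fail-then-pass signal that the per-test reward in `environment.py` builds on.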
__init__.py
CHANGED
|
@@ -1,18 +1,12 @@
|
|
| 1 |
-
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the BSD-style license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
from .client import MyEnv, SWEGymEnv
|
| 10 |
from .models import CodeAction, CodeObservation, TestResult
|
| 11 |
|
| 12 |
__all__ = [
|
| 13 |
"CodeAction",
|
| 14 |
"CodeObservation",
|
| 15 |
"TestResult",
|
| 16 |
-
"
|
| 17 |
"MyEnv",
|
| 18 |
]
|
|
|
|
| 1 |
+
"""TraceFix-RL OpenEnv package."""
|
| 2 |
|
| 3 |
+
from .client import MyEnv, TraceFixRLEnv
|
|
|
|
|
|
|
| 4 |
from .models import CodeAction, CodeObservation, TestResult
|
| 5 |
|
| 6 |
__all__ = [
|
| 7 |
"CodeAction",
|
| 8 |
"CodeObservation",
|
| 9 |
"TestResult",
|
| 10 |
+
"TraceFixRLEnv",
|
| 11 |
"MyEnv",
|
| 12 |
]
|
client.py
CHANGED
|
@@ -1,10 +1,4 @@
|
|
| 1 |
-
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the BSD-style license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
"""Client for the SWE-Gym OpenEnv environment."""
|
| 8 |
|
| 9 |
from typing import Dict
|
| 10 |
|
|
@@ -18,57 +12,15 @@ except ImportError:
|
|
| 18 |
from models import CodeAction, CodeObservation, TestResult
|
| 19 |
|
| 20 |
|
| 21 |
-
class
|
| 22 |
EnvClient[CodeAction, CodeObservation, State]
|
| 23 |
):
|
| 24 |
-
"""
|
| 25 |
-
Client for the SWE-Gym environment.
|
| 26 |
-
|
| 27 |
-
This client maintains a persistent WebSocket connection to the environment server,
|
| 28 |
-
enabling efficient multi-step interactions with lower latency.
|
| 29 |
-
Each client instance has its own dedicated environment session on the server.
|
| 30 |
-
|
| 31 |
-
Example:
|
| 32 |
-
>>> # Connect to a running server
|
| 33 |
-
>>> with SWEGymEnv(base_url="http://localhost:7860") as client:
|
| 34 |
-
... result = client.reset()
|
| 35 |
-
... print(result.observation.echoed_message)
|
| 36 |
-
...
|
| 37 |
-
... result = client.step(MyAction(message="Hello!"))
|
| 38 |
-
... print(result.observation.echoed_message)
|
| 39 |
-
|
| 40 |
-
Example with Docker:
|
| 41 |
-
>>> # Automatically start container and connect
|
| 42 |
-
>>> client = SWEGymEnv.from_docker_image("swe-gym:latest")
|
| 43 |
-
>>> try:
|
| 44 |
-
... result = client.reset()
|
| 45 |
-
... result = client.step(MyAction(message="Test"))
|
| 46 |
-
... finally:
|
| 47 |
-
... client.close()
|
| 48 |
-
"""
|
| 49 |
|
| 50 |
def _step_payload(self, action: CodeAction) -> Dict:
|
| 51 |
-
"""
|
| 52 |
-
Convert MyAction to JSON payload for step message.
|
| 53 |
-
|
| 54 |
-
Args:
|
| 55 |
-
action: MyAction instance
|
| 56 |
-
|
| 57 |
-
Returns:
|
| 58 |
-
Dictionary representation suitable for JSON encoding
|
| 59 |
-
"""
|
| 60 |
return action.model_dump(exclude_none=True)
|
| 61 |
|
| 62 |
def _parse_result(self, payload: Dict) -> StepResult[CodeObservation]:
|
| 63 |
-
"""
|
| 64 |
-
Parse server response into StepResult[CodeObservation].
|
| 65 |
-
|
| 66 |
-
Args:
|
| 67 |
-
payload: JSON response data from server
|
| 68 |
-
|
| 69 |
-
Returns:
|
| 70 |
-
StepResult with MyObservation
|
| 71 |
-
"""
|
| 72 |
obs_data = payload.get("observation", {})
|
| 73 |
observation = CodeObservation(
|
| 74 |
code_lines=obs_data.get("code_lines", []),
|
|
@@ -94,20 +46,10 @@ class SWEGymEnv(
|
|
| 94 |
)
|
| 95 |
|
| 96 |
def _parse_state(self, payload: Dict) -> State:
|
| 97 |
-
"""
|
| 98 |
-
Parse server response into State object.
|
| 99 |
-
|
| 100 |
-
Args:
|
| 101 |
-
payload: JSON response from state request
|
| 102 |
-
|
| 103 |
-
Returns:
|
| 104 |
-
State object with episode_id and step_count
|
| 105 |
-
"""
|
| 106 |
return State(
|
| 107 |
episode_id=payload.get("episode_id"),
|
| 108 |
step_count=payload.get("step_count", 0),
|
| 109 |
)
|
| 110 |
|
| 111 |
|
| 112 |
-
|
| 113 |
-
MyEnv = SWEGymEnv
|
|
|
|
| 1 |
+
"""Client for TraceFix-RL."""
|
| 2 |
|
| 3 |
from typing import Dict
|
| 4 |
|
|
|
|
| 12 |
from models import CodeAction, CodeObservation, TestResult
|
| 13 |
|
| 14 |
|
| 15 |
+
class TraceFixRLEnv(
|
| 16 |
EnvClient[CodeAction, CodeObservation, State]
|
| 17 |
):
|
| 18 |
+
"""Typed OpenEnv client."""
|
| 19 |
|
| 20 |
def _step_payload(self, action: CodeAction) -> Dict:
|
| 21 |
return action.model_dump(exclude_none=True)
|
| 22 |
|
| 23 |
def _parse_result(self, payload: Dict) -> StepResult[CodeObservation]:
|
| 24 |
obs_data = payload.get("observation", {})
|
| 25 |
observation = CodeObservation(
|
| 26 |
code_lines=obs_data.get("code_lines", []),
|
|
|
|
| 46 |
)
|
| 47 |
|
| 48 |
def _parse_state(self, payload: Dict) -> State:
|
| 49 |
return State(
|
| 50 |
episode_id=payload.get("episode_id"),
|
| 51 |
step_count=payload.get("step_count", 0),
|
| 52 |
)
|
| 53 |
|
| 54 |
|
| 55 |
+
MyEnv = TraceFixRLEnv
|
|
|
context.py
CHANGED
|
@@ -1,29 +1,11 @@
|
|
| 1 |
-
"""
|
| 2 |
-
context.py — Layered Context Compaction
|
| 3 |
-
=========================================
|
| 4 |
-
|
| 5 |
-
PRINCIPLE 10 — Layered Context Compaction
|
| 6 |
-
For large files, returning the full source on every observation would rapidly
|
| 7 |
-
fill the agent's context window, leaving no room for reasoning.
|
| 8 |
-
|
| 9 |
-
Instead we return a *localized* view: a ±WINDOW_LINES slice of the code
|
| 10 |
-
centred on the last line that was edited. This gives the agent exactly the
|
| 11 |
-
context it needs — the neighbourhood of its most recent change — without
|
| 12 |
-
flooding the context with unrelated code.
|
| 13 |
-
|
| 14 |
-
This module is intentionally pure (no environment state dependencies) so
|
| 15 |
-
it can be unit-tested independently and reused across environment versions.
|
| 16 |
-
"""
|
| 17 |
|
| 18 |
from __future__ import annotations
|
| 19 |
|
| 20 |
from typing import List, Optional
|
| 21 |
|
| 22 |
-
# How many lines above and below the anchor to include
|
| 23 |
WINDOW_LINES: int = 10
|
| 24 |
|
| 25 |
-
# Maximum characters for the localized context block
|
| 26 |
-
# (Principle 9: all outputs must be bounded)
|
| 27 |
MAX_CONTEXT_CHARS: int = 2_000
|
| 28 |
|
| 29 |
|
|
@@ -32,53 +14,19 @@ def get_localized_context(
|
|
| 32 |
anchor_line: Optional[int],
|
| 33 |
window: int = WINDOW_LINES,
|
| 34 |
) -> str:
|
| 35 |
-
"""
|
| 36 |
-
Return a ±`window`-line slice of `code_lines` centred on `anchor_line`.
|
| 37 |
-
|
| 38 |
-
Parameters
|
| 39 |
-
----------
|
| 40 |
-
code_lines : Full list of source lines (0-indexed internally).
|
| 41 |
-
anchor_line : The 1-indexed line number of the most recent edit.
|
| 42 |
-
If None (no edits yet) returns an empty string.
|
| 43 |
-
window : Number of lines to show above and below the anchor.
|
| 44 |
-
|
| 45 |
-
Returns
|
| 46 |
-
-------
|
| 47 |
-
A formatted string with line numbers, bounded to MAX_CONTEXT_CHARS,
|
| 48 |
-
annotated with the visible range and an anchor marker (▶).
|
| 49 |
-
|
| 50 |
-
Example output
|
| 51 |
-
--------------
|
| 52 |
-
[Showing lines 3–13 of 20, anchor ▶ line 7]
|
| 53 |
-
3 | left, right = 0, len(arr)
|
| 54 |
-
4 | while left <= right:
|
| 55 |
-
5 | mid = (left + right) // 2
|
| 56 |
-
6 | if arr[mid] == target:
|
| 57 |
-
7 ▶ return mid ← last edit
|
| 58 |
-
8 | elif arr[mid] < target:
|
| 59 |
-
9 | left = mid + 1
|
| 60 |
-
10 | else:
|
| 61 |
-
11 | right = mid - 1
|
| 62 |
-
12 | return -1
|
| 63 |
-
"""
|
| 64 |
if anchor_line is None or not code_lines:
|
| 65 |
return ""
|
| 66 |
|
| 67 |
total = len(code_lines)
|
| 68 |
|
| 69 |
-
# Clamp anchor into valid range
|
| 70 |
anchor_0 = max(0, min(anchor_line - 1, total - 1))
|
| 71 |
-
|
| 72 |
-
# Compute slice bounds (inclusive on both ends, 0-indexed)
|
| 73 |
start_0 = max(0, anchor_0 - window)
|
| 74 |
end_0 = min(total - 1, anchor_0 + window)
|
| 75 |
-
|
| 76 |
-
# Build header
|
| 77 |
start_1 = start_0 + 1
|
| 78 |
end_1 = end_0 + 1
|
| 79 |
header = f"[Showing lines {start_1}–{end_1} of {total}, anchor ▶ line {anchor_line}]"
|
| 80 |
|
| 81 |
-
# Build body
|
| 82 |
body_lines = []
|
| 83 |
for i in range(start_0, end_0 + 1):
|
| 84 |
line_num = i + 1
|
|
@@ -87,8 +35,7 @@ def get_localized_context(
|
|
| 87 |
|
| 88 |
result = header + "\n" + "\n".join(body_lines)
|
| 89 |
|
| 90 |
-
# PRINCIPLE 9 — hard cap on output size
|
| 91 |
if len(result) > MAX_CONTEXT_CHARS:
|
| 92 |
result = result[:MAX_CONTEXT_CHARS] + "\n... [context truncated]"
|
| 93 |
|
| 94 |
-
return result
|
|
|
|
| 1 |
+
"""Localized context helpers for TraceFix-RL."""
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
from typing import List, Optional
|
| 6 |
|
|
|
|
| 7 |
WINDOW_LINES: int = 10
|
| 8 |
|
|
|
|
|
|
|
| 9 |
MAX_CONTEXT_CHARS: int = 2_000
|
| 10 |
|
| 11 |
|
|
|
|
| 14 |
anchor_line: Optional[int],
|
| 15 |
window: int = WINDOW_LINES,
|
| 16 |
) -> str:
|
| 17 |
+
"""Return a bounded ±window slice around the latest edited line."""
|
| 18 |
if anchor_line is None or not code_lines:
|
| 19 |
return ""
|
| 20 |
|
| 21 |
total = len(code_lines)
|
| 22 |
|
|
|
|
| 23 |
anchor_0 = max(0, min(anchor_line - 1, total - 1))
|
|
|
|
|
|
|
| 24 |
start_0 = max(0, anchor_0 - window)
|
| 25 |
end_0 = min(total - 1, anchor_0 + window)
|
|
|
|
|
|
|
| 26 |
start_1 = start_0 + 1
|
| 27 |
end_1 = end_0 + 1
|
| 28 |
header = f"[Showing lines {start_1}–{end_1} of {total}, anchor ▶ line {anchor_line}]"
|
| 29 |
|
|
|
|
| 30 |
body_lines = []
|
| 31 |
for i in range(start_0, end_0 + 1):
|
| 32 |
line_num = i + 1
|
|
|
|
| 35 |
|
| 36 |
result = header + "\n" + "\n".join(body_lines)
|
| 37 |
|
|
|
|
| 38 |
if len(result) > MAX_CONTEXT_CHARS:
|
| 39 |
result = result[:MAX_CONTEXT_CHARS] + "\n... [context truncated]"
|
| 40 |
|
| 41 |
+
return result
|
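Since this commit trims the long docstring from `get_localized_context`, a standalone sketch of the windowing logic may help readers of the diff. The clamping, slice bounds, and truncation below follow the lines visible in the hunks; the per-line formatting inside the loop is elided in the diff, so the marker layout here is an assumption loosely based on the removed docstring example. The name `localized_context_sketch` is hypothetical, and the header uses plain ASCII instead of the original symbols.

```python
WINDOW_LINES = 10
MAX_CONTEXT_CHARS = 2_000


def localized_context_sketch(code_lines, anchor_line, window=WINDOW_LINES):
    """Return a bounded slice of code_lines around the 1-indexed anchor_line."""
    if anchor_line is None or not code_lines:
        return ""

    total = len(code_lines)
    # Clamp the anchor into range, then take `window` lines on each side.
    anchor_0 = max(0, min(anchor_line - 1, total - 1))
    start_0 = max(0, anchor_0 - window)
    end_0 = min(total - 1, anchor_0 + window)

    header = f"[Showing lines {start_0 + 1}-{end_0 + 1} of {total}, anchor line {anchor_line}]"

    body = []
    for i in range(start_0, end_0 + 1):
        # Assumed layout: ">" marks the anchor line, "|" marks the others.
        marker = ">" if i == anchor_0 else "|"
        body.append(f"{i + 1:>4} {marker} {code_lines[i]}")

    result = header + "\n" + "\n".join(body)
    # Hard cap on output size, as in the original function.
    if len(result) > MAX_CONTEXT_CHARS:
        result = result[:MAX_CONTEXT_CHARS] + "\n... [context truncated]"
    return result
```

For a 30-line file with the anchor on line 15 and `window=2`, the sketch returns a header plus lines 13 to 17, with line 15 marked.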
environment.py
CHANGED
|
@@ -1,45 +1,4 @@
|
|
| 1 |
-
"""
|
| 2 |
-
environment.py — Python Debugging Gym (Core RL Environment)
|
| 3 |
-
=============================================================
|
| 4 |
-
|
| 5 |
-
PRINCIPLE 1 — You Don't Design the Control Flow
|
| 6 |
-
The agent decides the sequence of actions. step() is a pure router:
|
| 7 |
-
it receives whatever action the agent chose (in whatever order),
|
| 8 |
-
processes it, and returns the new state. There is no forced sequence,
|
| 9 |
-
no "you must VIEW_CODE before RUN_TESTS" gate. The system prompt
|
| 10 |
-
explains what tools exist; the agent decides how to use them.
|
| 11 |
-
|
| 12 |
-
PRINCIPLE 5 — Cost-Per-Turn Reward Logic
|
| 13 |
-
Each call to step() costs R_STEP_COST = -0.01. This makes the episode
|
| 14 |
-
a multi-turn budget problem: the agent is rewarded for solving quickly.
|
| 15 |
-
An agent that solves in 4 steps scores ~0.14 more than one that takes
|
| 16 |
-
18 steps to reach the same solution.
|
| 17 |
-
|
| 18 |
-
PRINCIPLE 7 — The Prompt is Code
|
| 19 |
-
The string returned by reset() is the agent's complete operational
|
| 20 |
-
contract for the session. It states: the goal, the available actions
|
| 21 |
-
(with exact JSON examples), the reward structure, the current code,
|
| 22 |
-
and the expected termination condition. Ambiguity in this string
|
| 23 |
-
directly causes off-task behaviour.
|
| 24 |
-
|
| 25 |
-
PRINCIPLE 10 — Layered Context Compaction
|
| 26 |
-
_build_observation() tracks `_last_edited_line` and passes it to
|
| 27 |
-
context.get_localized_context() to produce a focused ±10-line view
|
| 28 |
-
after each write action. This prevents the observation from inflating
|
| 29 |
-
the agent's context window on large files.
|
| 30 |
-
|
| 31 |
-
Reward table (dense, non-sparse — every step emits a signal):
|
| 32 |
-
+1.00 SUBMIT and ALL tests pass → episode solved
|
| 33 |
-
+0.10 RUN_TESTS called → information-gathering rewarded
|
| 34 |
-
+0.05 Per test transitioning fail→pass on a RUN_TESTS or SUBMIT
|
| 35 |
-
-0.01 Every step taken → efficiency pressure (Principle 5)
|
| 36 |
-
-0.10 Syntax error detected → broken code penalised immediately
|
| 37 |
-
-0.10 UNDO_EDIT or RESET_TO_ORIGINAL → backtracking discouraged
|
| 38 |
-
-0.02 Invalid line range supplied → hallucination deterrent
|
| 39 |
-
-0.20 SUBMIT with tests still failing
|
| 40 |
-
|
| 41 |
-
Max episode length: 50 steps.
|
| 42 |
-
"""
|
| 43 |
|
| 44 |
from __future__ import annotations
|
| 45 |
|
|
@@ -59,33 +18,18 @@ except ImportError:
|
|
| 59 |
from tasks import ALL_TASKS, TASKS_BY_DIFFICULTY
|
| 60 |
|
| 61 |
|
| 62 |
-
# ---------------------------------------------------------------------------
|
| 63 |
-
# Reward constants
|
| 64 |
-
# ---------------------------------------------------------------------------
|
| 65 |
-
|
| 66 |
R_SUBMIT_ALL_PASS = +1.00
|
| 67 |
R_SUBMIT_FAIL = -0.20
|
| 68 |
R_SYNTAX_ERROR = -0.10
|
| 69 |
R_RUN_TESTS = +0.10
|
| 70 |
R_PER_NEW_PASS = +0.05
|
| 71 |
-
R_STEP_COST = -0.01
|
| 72 |
R_INVALID_LINE = -0.02
|
| 73 |
R_DESTRUCTIVE_PENALTY = -0.20
|
| 74 |
-
R_UNDO_RESET = -0.10
|
| 75 |
|
| 76 |
MAX_STEPS: int = 50
|
| 77 |
|
| 78 |
-
|
| 79 |
-
# ---------------------------------------------------------------------------
|
| 80 |
-
# System Prompt (PRINCIPLE 7 — The Prompt is Code)
|
| 81 |
-
# ---------------------------------------------------------------------------
|
| 82 |
-
# This string is the agent's entire operational contract.
|
| 83 |
-
# It must be:
|
| 84 |
-
# • Self-contained (no assumed context from training data)
|
| 85 |
-
# • Precise (exact JSON examples, not vague descriptions)
|
| 86 |
-
# • Non-directive about sequence (Principle 1: agent chooses order)
|
| 87 |
-
# • Complete (goal, tools, rewards, termination — nothing omitted)
|
| 88 |
-
|
| 89 |
_SYSTEM_PROMPT = """\
|
| 90 |
╔══════════════════════════════════════════════════════╗
|
| 91 |
║ PYTHON DEBUGGING GYM — EPISODE BRIEF ║
|
|
@@ -176,24 +120,10 @@ CURRENT CODE (this is the broken version — fix it)
|
|
| 176 |
"""
|
| 177 |
|
| 178 |
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
# ---------------------------------------------------------------------------
|
| 182 |
-
|
| 183 |
-
class PythonDebuggingGym:
|
| 184 |
-
"""
|
| 185 |
-
Gymnasium-compatible RL environment for Python debugging.
|
| 186 |
-
|
| 187 |
-
PRINCIPLE 1: step() is a stateless router — the agent chooses the
|
| 188 |
-
sequence. No internal gates, no forced ordering between actions.
|
| 189 |
-
|
| 190 |
-
Interface
|
| 191 |
-
---------
|
| 192 |
-
obs, system_prompt = env.reset()
|
| 193 |
-
obs, reward, done, info = env.step(action: CodeAction)
|
| 194 |
-
"""
|
| 195 |
|
| 196 |
-
metadata = {"name": "
|
| 197 |
|
| 198 |
def __init__(
|
| 199 |
self,
|
|
@@ -203,25 +133,21 @@ class PythonDebuggingGym:
|
|
| 203 |
self._task_index = task_index
|
| 204 |
self._rng = random.Random(seed)
|
| 205 |
|
| 206 |
-
# All mutable episode state lives here; reset() wipes every field.
|
| 207 |
self._code_lines: List[str] = []
|
| 208 |
self._task: Dict[str, Any] = {}
|
| 209 |
self._step_count: int = 0
|
| 210 |
self._prev_pass_count: int = 0
|
| 211 |
self._last_test_results: List[TestResult] = []
|
| 212 |
self._last_output: str = ""
|
| 213 |
-
self._last_edited_line: Optional[int] = None
|
| 214 |
self._episode_id: str = ""
|
| 215 |
self._done: bool = False
|
| 216 |
self._cumulative_reward: float = 0.0
|
| 217 |
-
self._accumulated_step_costs: float = 0.0
|
| 218 |
-
|
| 219 |
-
self.
|
| 220 |
-
self._edit_history: List[List[str]] = [] # stack of pre-edit snapshots
|
| 221 |
-
# Curriculum learning — persists across episodes, incremented externally
|
| 222 |
self.training_step: int = 0
|
| 223 |
|
| 224 |
-
# ── Curriculum task sampler ──────────────────────────────────────────────
|
| 225 |
|
| 226 |
def _sample_task(self, task_override=None) -> Dict[str, Any]:
|
| 227 |
"""
|
|
@@ -241,13 +167,11 @@ class PythonDebuggingGym:
|
|
| 241 |
if isinstance(task_override, dict):
|
| 242 |
return task_override
|
| 243 |
|
| 244 |
-
# Judge-safe default: no training_step set → random from all tasks
|
| 245 |
if self.training_step == 0:
|
| 246 |
if not ALL_TASKS:
|
| 247 |
raise RuntimeError("ALL_TASKS is empty — check tasks.py.")
|
| 248 |
return self._rng.choice(ALL_TASKS)
|
| 249 |
|
| 250 |
-
# Curriculum mode (trainer increments training_step between episodes)
|
| 251 |
if self.training_step < 1000:
|
| 252 |
bucket = "easy"
|
| 253 |
elif self.training_step < 5000:
|
|
@@ -257,7 +181,6 @@ class PythonDebuggingGym:
|
|
| 257 |
|
| 258 |
pool = TASKS_BY_DIFFICULTY.get(bucket, [])
|
| 259 |
if not pool:
|
| 260 |
-
# Fallback: any non-empty bucket rather than crashing
|
| 261 |
for b in ("easy", "medium", "hard"):
|
| 262 |
pool = TASKS_BY_DIFFICULTY.get(b, [])
|
| 263 |
if pool:
|
|
@@ -267,7 +190,6 @@ class PythonDebuggingGym:
|
|
| 267 |
|
| 268 |
return self._rng.choice(pool)
|
| 269 |
|
| 270 |
-
# ── reset() ─────────────────────────────────────────────────────────────
|
| 271 |
|
| 272 |
def reset(
|
| 273 |
self, *, task_index: Optional[int] = None
|
|
@@ -281,7 +203,6 @@ class PythonDebuggingGym:
|
|
| 281 |
"""
|
| 282 |
self._task = self._sample_task(task_index)
|
| 283 |
|
| 284 |
-
# ── Complete state wipe ──────────────────────────────────────────
|
| 285 |
self._code_lines = list(self._task["code"]) # deep copy — no alias
|
| 286 |
self._step_count = 0
|
| 287 |
self._prev_pass_count = 0
|
|
@@ -292,16 +213,13 @@ class PythonDebuggingGym:
|
|
| 292 |
self._done = False
|
| 293 |
self._cumulative_reward = 0.0
|
| 294 |
self._accumulated_step_costs = 0.0
|
| 295 |
-
# Mini-Git: seed pristine snapshot and clear history
|
| 296 |
self._original_code = list(self._task["code"]) # separate copy from _code_lines
|
| 297 |
self._edit_history = []
|
| 298 |
-
# Anti-Loop history
|
| 299 |
self._last_action: Optional[str] = None
|
| 300 |
self._consecutive_count: int = 0
|
| 301 |
|
| 302 |
obs = self._build_observation(reward=0.0)
|
| 303 |
|
| 304 |
-
# PRINCIPLE 7: build the operational contract string
|
| 305 |
system_prompt = _SYSTEM_PROMPT.format(
|
| 306 |
task_name = self._task["name"],
|
| 307 |
difficulty = self._task.get("difficulty", "unknown"),
|
|
@@ -312,7 +230,6 @@ class PythonDebuggingGym:
|
|
| 312 |
|
| 313 |
return obs, system_prompt
|
| 314 |
|
| 315 |
-
# ── step() ──────────────────────────────────────────────────────────────
|
| 316 |
|
| 317 |
def step(
|
| 318 |
self, action: CodeAction
|
|
@@ -337,7 +254,6 @@ class PythonDebuggingGym:
|
|
| 337 |
reward = R_STEP_COST # PRINCIPLE 5: cost-per-turn baseline
|
| 338 |
self._accumulated_step_costs += abs(R_STEP_COST) # Hackathon compliance
|
| 339 |
|
| 340 |
-
# ── Repetition Penalty (Anti-Loop) ───────────────────────────────
|
| 341 |
if action.action_type == self._last_action:
|
| 342 |
self._consecutive_count += 1
|
| 343 |
reward += -0.05 * self._consecutive_count
|
|
@@ -345,7 +261,6 @@ class PythonDebuggingGym:
|
|
| 345 |
self._consecutive_count = 0
|
| 346 |
self._last_action = action.action_type
|
| 347 |
|
| 348 |
-
# ── Route (PRINCIPLE 1: no forced sequence) ──────────────────────
|
| 349 |
atype = action.action_type
|
| 350 |
|
| 351 |
if atype == "VIEW_CODE":
|
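The two hunks above show the per-step baseline cost and the repetition penalty, with the branch between them elided by the diff. A minimal sketch of how they combine, assuming the elided branch is a plain `else` that resets the counter; `RepetitionTracker` and `REPEAT_PENALTY` are hypothetical names (the source uses the literal `-0.05`).

```python
R_STEP_COST = -0.01
REPEAT_PENALTY = -0.05  # appears as a literal in step(); named here for clarity


class RepetitionTracker:
    """Tracks consecutive identical actions and returns the base reward per step."""

    def __init__(self):
        self.last_action = None
        self.consecutive_count = 0

    def base_reward(self, action_type):
        reward = R_STEP_COST  # every step pays the baseline cost
        if action_type == self.last_action:
            # The penalty grows linearly with the length of the repeat streak.
            self.consecutive_count += 1
            reward += REPEAT_PENALTY * self.consecutive_count
        else:
            self.consecutive_count = 0  # assumed reset branch
        self.last_action = action_type
        return round(reward, 4)
```

Three `VIEW_CODE` calls in a row cost -0.01, -0.06, and -0.11 before any action-specific reward is added, and switching to a different action resets the streak.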
|
@@ -369,12 +284,8 @@ class PythonDebuggingGym:
|
|
| 369 |
reward += self._act_submit()
|
| 370 |
self._done = True
|
| 371 |
|
| 372 |
-
# ── Max-steps termination ────────────────────────────────────────
|
| 373 |
if self._step_count >= MAX_STEPS and not self._done:
|
| 374 |
self._done = True
|
| 375 |
-
# Deterministic clamp — never trust the LLM to call SUBMIT.
|
| 376 |
-
# Evaluate the current code and produce a valid [0.0, 1.0] score
|
| 377 |
-
# regardless of how the episode ended.
|
| 378 |
_, results, syntax_err = run_code_with_tests(
|
| 379 |
source=self._source(),
|
| 380 |
test_callables=self._task["tests"],
|
|
@@ -398,15 +309,10 @@ class PythonDebuggingGym:
|
|
| 398 |
"step": self._step_count,
|
| 399 |
}
|
| 400 |
if self._done:
|
| 401 |
-
# PRINCIPLE: Ensure Hackathon score leak doesn't occur. It must be strictly [0.0, 1.0].
|
| 402 |
-
# During SUBMIT, reward might be negative if _act_submit returned 0.0 added to -0.01.
|
| 403 |
info["final_score"] = max(0.0, min(1.0, round(reward, 4)))
|
| 404 |
|
| 405 |
return obs, round(reward, 4), self._done, info
|
| 406 |
|
| 407 |
-
# ── Action handlers ─────────────────────────────────────────────────────
|
| 408 |
-
# Each returns the delta reward (R_STEP_COST already applied by step()).
|
| 409 |
-
# Handlers update self._last_output and self._last_edited_line as needed.
|
| 410 |
|
| 411 |
def _act_view_code(self) -> float:
|
| 412 |
self._last_output = (
|
|
@@ -416,7 +322,6 @@ class PythonDebuggingGym:
|
|
| 416 |
for i, line in enumerate(self._code_lines)
|
| 417 |
)
|
| 418 |
)
|
| 419 |
-
# VIEW_CODE does not change the code — localized_context stays where it was
|
| 420 |
return 0.0
|
| 421 |
|
| 422 |
def _act_run_tests(self) -> float:
|
|
@@ -447,12 +352,10 @@ class PythonDebuggingGym:
|
|
| 447 |
if new_code_block is None:
|
| 448 |
new_code_block = ""
|
| 449 |
|
| 450 |
-
# ── Guard: Destructive Action (Anti-Deletion) ─────────────────────
|
| 451 |
if len(new_code_block) == 0 and (end_line - start_line) > 5:
|
| 452 |
self._last_output = "Error: Cannot delete more than 5 lines at once."
|
| 453 |
return R_DESTRUCTIVE_PENALTY
|
| 454 |
|
| 455 |
-
# ── Guard: inverted range ─────────────────────────────────────────
|
| 456 |
if start_line > end_line:
|
| 457 |
self._last_output = (
|
| 458 |
f"Error: start_line ({start_line}) > end_line ({end_line}). "
|
|
@@ -460,7 +363,6 @@ class PythonDebuggingGym:
|
|
| 460 |
)
|
| 461 |
return R_INVALID_LINE
|
| 462 |
|
| 463 |
-
# ── Guard: out-of-bounds ──────────────────────────────────────────
|
| 464 |
if start_line < 1 or start_line > n:
|
| 465 |
self._last_output = (
|
| 466 |
f"Error: start_line {start_line} is out of range [1, {n}]. "
|
|
@@ -474,19 +376,14 @@ class PythonDebuggingGym:
|
|
| 474 |
)
|
| 475 |
return R_INVALID_LINE
|
| 476 |
|
| 477 |
-
# ── Slice assignment (PRINCIPLE 1: pure data transformation) ──────
|
| 478 |
start_idx = start_line - 1 # convert to 0-indexed
|
| 479 |
end_idx = end_line # exclusive upper bound for Python slice
|
| 480 |
|
| 481 |
-
# ── Mini-Git: snapshot BEFORE mutating (Phase 2) ─────────────────
|
| 482 |
self._edit_history.append(list(self._code_lines))
|
| 483 |
|
| 484 |
new_lines = new_code_block.split("\n")
|
| 485 |
self._code_lines[start_idx:end_idx] = new_lines
|
| 486 |
|
| 487 |
-
# ── Anchor context at END of new block (PRINCIPLE 10) ─────────────
|
| 488 |
-
# If the agent replaces lines 5–10 with 20 new lines, the anchor
|
| 489 |
-
# settles at start_line + len(new_lines) - 1, clamped to file length.
|
| 490 |
new_end = start_line + len(new_lines) - 1
|
| 491 |
self._last_edited_line = min(new_end, len(self._code_lines))
|
| 492 |
|
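The `REPLACE_LINES` hunks above convert a 1-indexed inclusive range into a Python slice and then move the context anchor to the end of the inserted block. The sketch below isolates that arithmetic. It is an illustration only: `replace_lines_sketch` is a hypothetical name, it raises `ValueError` where the environment returns a penalty and an error message, and the `end_line` bound check is an assumption because that guard falls in a gap between hunks.

```python
def replace_lines_sketch(code_lines, start_line, end_line, new_code_block):
    """Replace lines start_line..end_line (1-indexed, inclusive) with new_code_block."""
    n = len(code_lines)
    if start_line > end_line:
        raise ValueError("start_line is greater than end_line")
    if start_line < 1 or start_line > n:
        raise ValueError("start_line is out of range")
    if end_line > n:  # assumed guard; not visible in the diff
        raise ValueError("end_line is out of range")

    new_lines = new_code_block.split("\n")
    updated = list(code_lines)  # work on a copy, like the pre-edit snapshot
    # start_line - 1 is the 0-indexed start; end_line is the exclusive slice end.
    updated[start_line - 1:end_line] = new_lines

    # The anchor settles on the last inserted line, clamped to the file length.
    anchor = min(start_line + len(new_lines) - 1, len(updated))
    return updated, anchor
```

Note that `"".split("\n")` yields `[""]`, so an empty replacement block leaves one blank line behind rather than removing the range outright.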
|
@@ -514,9 +411,6 @@ class PythonDebuggingGym:
|
|
| 514 |
if syntax_err:
|
| 515 |
self._last_output += "\n❌ SUBMIT rejected — syntax error in current code."
|
| 516 |
|
| 517 |
-
# ── Hackathon compliance: final score ∈ [0.0, 1.0] ───────────────
|
| 518 |
-
# raw = (tests_passed / total) - accumulated_step_costs
|
| 519 |
-
# Then clamped so the grader always receives a value in spec.
|
| 520 |
proportion = passes / total if total > 0 else 0.0
|
| 521 |
raw_score = proportion - self._accumulated_step_costs
|
| 522 |
final_score = max(0.0, min(1.0, raw_score))
|
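The final-score lines in the hunk above subtract the accumulated step costs from the pass proportion and clamp the result to [0.0, 1.0]. A small worked sketch of that arithmetic, assuming every step adds `abs(R_STEP_COST)` to the accumulated cost as the `step()` hunk suggests; `final_score_sketch` is a hypothetical helper, not a function in `environment.py`.

```python
R_STEP_COST = -0.01


def final_score_sketch(passes, total, steps):
    """Pass proportion minus per-step costs, clamped to the range [0.0, 1.0]."""
    proportion = passes / total if total > 0 else 0.0
    accumulated_step_costs = steps * abs(R_STEP_COST)
    raw_score = proportion - accumulated_step_costs
    return max(0.0, min(1.0, raw_score))
```

Under this reading, solving all tests in 4 steps scores about 0.96 and solving them in 18 steps about 0.82, which matches the 0.14 gap quoted in the removed module docstring.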
|
@@ -559,7 +453,6 @@ class PythonDebuggingGym:
|
|
| 559 |
"Call VIEW_CODE to inspect the restored file."
|
| 560 |
)
|
| 561 |
|
| 562 |
-
# PRINCIPLE 10 desync fix: anchor is stale after rollback — wipe it.
|
| 563 |
self._last_edited_line = None
|
| 564 |
return R_UNDO_RESET
|
| 565 |
|
|
@@ -580,11 +473,9 @@ class PythonDebuggingGym:
|
|
| 580 |
"Call VIEW_CODE to inspect the file."
|
| 581 |
)
|
| 582 |
|
| 583 |
-
# PRINCIPLE 10 desync fix: context anchor is meaningless after full reset.
|
| 584 |
self._last_edited_line = None
|
| 585 |
return R_UNDO_RESET
|
| 586 |
|
| 587 |
-
# ── Helpers ─────────────────────────────────────────────────────────────
|
| 588 |
|
| 589 |
def _source(self) -> str:
|
| 590 |
return "\n".join(self._code_lines)
|
|
@@ -592,7 +483,6 @@ class PythonDebuggingGym:
|
|
| 592 |
def _build_observation(self, reward: float) -> CodeObservation:
|
| 593 |
syntax_valid, _ = check_syntax(self._source())
|
| 594 |
|
| 595 |
-
# PRINCIPLE 10: localized context — only ±10 lines around last edit
|
| 596 |
localized = get_localized_context(self._code_lines, self._last_edited_line)
|
| 597 |
|
| 598 |
return CodeObservation(
|
|
|
|
| 1 |
+
"""Core TraceFix-RL environment implementation."""
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
|
|
|
| 18 |
from tasks import ALL_TASKS, TASKS_BY_DIFFICULTY
|
| 19 |
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
R_SUBMIT_ALL_PASS = +1.00
|
| 22 |
R_SUBMIT_FAIL = -0.20
|
| 23 |
R_SYNTAX_ERROR = -0.10
|
| 24 |
R_RUN_TESTS = +0.10
|
| 25 |
R_PER_NEW_PASS = +0.05
|
| 26 |
+
R_STEP_COST = -0.01
|
| 27 |
R_INVALID_LINE = -0.02
|
| 28 |
R_DESTRUCTIVE_PENALTY = -0.20
|
| 29 |
+
R_UNDO_RESET = -0.10
|
| 30 |
|
| 31 |
MAX_STEPS: int = 50
|
| 32 |
|
| 33 |
_SYSTEM_PROMPT = """\
|
| 34 |
╔══════════════════════════════════════════════════════╗
|
| 35 |
║ PYTHON DEBUGGING GYM — EPISODE BRIEF ║
|
|
|
|
| 120 |
"""
|
| 121 |
|
| 122 |
|
| 123 |
+
class TraceFixRLGym:
|
| 124 |
+
"""Gym-style environment with reset/step methods."""
|
| 125 |
|
| 126 |
+
metadata = {"name": "TraceFixRL-v1", "render_modes": []}
|
| 127 |
|
| 128 |
def __init__(
|
| 129 |
self,
|
|
|
|
| 133 |
self._task_index = task_index
|
| 134 |
self._rng = random.Random(seed)
|
| 135 |
|
|
|
|
| 136 |
self._code_lines: List[str] = []
|
| 137 |
self._task: Dict[str, Any] = {}
|
| 138 |
self._step_count: int = 0
|
| 139 |
self._prev_pass_count: int = 0
|
| 140 |
self._last_test_results: List[TestResult] = []
|
| 141 |
self._last_output: str = ""
|
| 142 |
+
self._last_edited_line: Optional[int] = None
|
| 143 |
self._episode_id: str = ""
|
| 144 |
self._done: bool = False
|
| 145 |
self._cumulative_reward: float = 0.0
|
| 146 |
+
self._accumulated_step_costs: float = 0.0
|
| 147 |
+
self._original_code: List[str] = []
|
| 148 |
+
self._edit_history: List[List[str]] = []
|
|
|
|
|
|
|
| 149 |
self.training_step: int = 0
|
| 150 |
|
|
|
|
| 151 |
|
| 152 |
def _sample_task(self, task_override=None) -> Dict[str, Any]:
|
| 153 |
"""
|
|
|
|
| 167 |
if isinstance(task_override, dict):
|
| 168 |
return task_override
|
| 169 |
|
|
|
|
| 170 |
if self.training_step == 0:
|
| 171 |
if not ALL_TASKS:
|
| 172 |
raise RuntimeError("ALL_TASKS is empty — check tasks.py.")
|
| 173 |
return self._rng.choice(ALL_TASKS)
|
| 174 |
|
|
|
|
| 175 |
if self.training_step < 1000:
|
| 176 |
bucket = "easy"
|
| 177 |
elif self.training_step < 5000:
|
|
|
|
| 181 |
|
| 182 |
pool = TASKS_BY_DIFFICULTY.get(bucket, [])
|
| 183 |
if not pool:
|
|
|
|
| 184 |
for b in ("easy", "medium", "hard"):
|
| 185 |
pool = TASKS_BY_DIFFICULTY.get(b, [])
|
| 186 |
if pool:
|
|
|
|
| 190 |
|
| 191 |
return self._rng.choice(pool)
|
| 192 |
|
|
|
|
| 193 |
|
| 194 |
def reset(
|
| 195 |
self, *, task_index: Optional[int] = None
|
|
|
|
| 203 |
"""
|
| 204 |
self._task = self._sample_task(task_index)
|
| 205 |
|
|
|
|
| 206 |
self._code_lines = list(self._task["code"]) # deep copy — no alias
|
| 207 |
self._step_count = 0
|
| 208 |
self._prev_pass_count = 0
|
|
|
|
| 213 |
self._done = False
|
| 214 |
self._cumulative_reward = 0.0
|
| 215 |
self._accumulated_step_costs = 0.0
|
|
|
|
| 216 |
self._original_code = list(self._task["code"]) # separate copy from _code_lines
|
| 217 |
self._edit_history = []
|
|
|
|
| 218 |
self._last_action: Optional[str] = None
|
| 219 |
self._consecutive_count: int = 0
|
| 220 |
|
| 221 |
obs = self._build_observation(reward=0.0)
|
| 222 |
|
|
|
|
| 223 |
system_prompt = _SYSTEM_PROMPT.format(
|
| 224 |
task_name = self._task["name"],
|
| 225 |
difficulty = self._task.get("difficulty", "unknown"),
|
|
|
|
| 230 |
|
| 231 |
return obs, system_prompt
|
| 232 |
|
|
|
|
| 233 |
|
| 234 |
def step(
|
| 235 |
self, action: CodeAction
|
|
|
|
| 254 |
reward = R_STEP_COST # PRINCIPLE 5: cost-per-turn baseline
|
| 255 |
self._accumulated_step_costs += abs(R_STEP_COST) # Hackathon compliance
|
| 256 |
|
|
|
|
| 257 |
if action.action_type == self._last_action:
|
| 258 |
self._consecutive_count += 1
|
| 259 |
reward += -0.05 * self._consecutive_count
|
|
|
|
| 261 |
self._consecutive_count = 0
|
| 262 |
self._last_action = action.action_type
|
| 263 |
|
|
|
|
| 264 |
atype = action.action_type
|
| 265 |
|
| 266 |
if atype == "VIEW_CODE":
|
|
|
|
| 284 |
reward += self._act_submit()
|
| 285 |
self._done = True
|
| 286 |
|
|
|
|
| 287 |
if self._step_count >= MAX_STEPS and not self._done:
|
| 288 |
self._done = True
| 289 |
_, results, syntax_err = run_code_with_tests(
|
| 290 |
source=self._source(),
|
| 291 |
test_callables=self._task["tests"],
|
|
|
|
| 309 |
"step": self._step_count,
|
| 310 |
}
|
| 311 |
if self._done:
|
|
|
|
|
|
|
| 312 |
info["final_score"] = max(0.0, min(1.0, round(reward, 4)))
|
| 313 |
|
| 314 |
return obs, round(reward, 4), self._done, info
|
| 315 |
|
|
|
|
|
|
|
|
|
|
| 316 |
|
| 317 |
def _act_view_code(self) -> float:
|
| 318 |
self._last_output = (
|
|
|
|
| 322 |
for i, line in enumerate(self._code_lines)
|
| 323 |
)
|
| 324 |
)
|
|
|
|
| 325 |
return 0.0
|
| 326 |
|
| 327 |
def _act_run_tests(self) -> float:
|
|
|
|
| 352 |
if new_code_block is None:
|
| 353 |
new_code_block = ""
|
| 354 |
|
|
|
|
| 355 |
if len(new_code_block) == 0 and (end_line - start_line) > 5:
|
| 356 |
self._last_output = "Error: Cannot delete more than 5 lines at once."
|
| 357 |
return R_DESTRUCTIVE_PENALTY
|
| 358 |
|
|
|
|
| 359 |
if start_line > end_line:
|
| 360 |
self._last_output = (
|
| 361 |
f"Error: start_line ({start_line}) > end_line ({end_line}). "
|
|
|
|
| 363 |
)
|
| 364 |
return R_INVALID_LINE
|
| 365 |
|
|
|
|
| 366 |
if start_line < 1 or start_line > n:
|
| 367 |
self._last_output = (
|
| 368 |
f"Error: start_line {start_line} is out of range [1, {n}]. "
|
|
|
|
| 376 |
)
|
| 377 |
return R_INVALID_LINE
|
| 378 |
|
|
|
|
| 379 |
start_idx = start_line - 1 # convert to 0-indexed
|
| 380 |
end_idx = end_line # exclusive upper bound for Python slice
|
| 381 |
|
|
|
|
| 382 |
self._edit_history.append(list(self._code_lines))
|
| 383 |
|
| 384 |
new_lines = new_code_block.split("\n")
|
| 385 |
self._code_lines[start_idx:end_idx] = new_lines
|
| 386 |
|
|
|
|
|
|
|
|
|
|
| 387 |
new_end = start_line + len(new_lines) - 1
|
| 388 |
self._last_edited_line = min(new_end, len(self._code_lines))
|
| 389 |
|
|
|
|
| 411 |
if syntax_err:
|
| 412 |
self._last_output += "\n❌ SUBMIT rejected — syntax error in current code."
|
| 413 |
|
|
|
|
|
|
|
|
|
|
| 414 |
proportion = passes / total if total > 0 else 0.0
|
| 415 |
raw_score = proportion - self._accumulated_step_costs
|
| 416 |
final_score = max(0.0, min(1.0, raw_score))
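The submit hunk above scores an episode as the fraction of passing tests minus the step costs accrued so far, clamped to [0, 1]. A minimal sketch of that arithmetic, assuming a hypothetical standalone function `final_score` (not part of the commit) and assuming that every step adds `abs(R_STEP_COST)` to the accumulated cost, as the `step` hunk suggests:

```python
R_STEP_COST = -0.01  # per-step cost, taken from the reward constants in the diff


def final_score(passes: int, total: int, steps: int) -> float:
    """Hypothetical helper: pass proportion minus accrued step costs, clamped to [0, 1]."""
    proportion = passes / total if total > 0 else 0.0
    # Sketch only: the environment accumulates this cost step by step.
    accumulated_step_costs = steps * abs(R_STEP_COST)
    raw_score = proportion - accumulated_step_costs
    return max(0.0, min(1.0, raw_score))


all_pass_after_ten_steps = final_score(passes=4, total=4, steps=10)
no_pass_after_five_steps = final_score(passes=0, total=4, steps=5)
```

Under this reading, a perfect submission still loses a little score for each step taken, and a failing submission can never go below zero.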
|
|
|
|
| 453 |
"Call VIEW_CODE to inspect the restored file."
|
| 454 |
)
|
| 455 |
|
|
|
|
| 456 |
self._last_edited_line = None
|
| 457 |
return R_UNDO_RESET
|
| 458 |
|
|
|
|
| 473 |
"Call VIEW_CODE to inspect the file."
|
| 474 |
)
|
| 475 |
|
|
|
|
| 476 |
self._last_edited_line = None
|
| 477 |
return R_UNDO_RESET
|
| 478 |
|
|
|
|
| 479 |
|
| 480 |
def _source(self) -> str:
|
| 481 |
return "\n".join(self._code_lines)
|
|
|
|
| 483 |
def _build_observation(self, reward: float) -> CodeObservation:
|
| 484 |
syntax_valid, _ = check_syntax(self._source())
|
| 485 |
|
|
|
|
| 486 |
localized = get_localized_context(self._code_lines, self._last_edited_line)
|
| 487 |
|
| 488 |
return CodeObservation(
|
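The edit hunk in `environment.py` converts a 1-indexed, inclusive line range into a Python slice before splicing in the replacement text. A minimal sketch of that conversion, assuming a hypothetical standalone helper `replace_lines` (not part of the commit) that mirrors the `start_idx` / `end_idx` computation and omits the range validation shown in the diff:

```python
from typing import List, Tuple


def replace_lines(
    code_lines: List[str], start_line: int, end_line: int, new_code_block: str
) -> Tuple[List[str], int]:
    """Hypothetical helper: replace 1-indexed inclusive lines [start_line, end_line]."""
    start_idx = start_line - 1  # convert to 0-indexed
    end_idx = end_line          # exclusive upper bound for a Python slice
    new_lines = new_code_block.split("\n")
    updated = list(code_lines)  # copy so the caller's list is left untouched
    updated[start_idx:end_idx] = new_lines
    # Last line touched by the edit, clamped to the new file length.
    last_edited = min(start_line + len(new_lines) - 1, len(updated))
    return updated, last_edited


lines = ["def add(a, b):", "    return a - b", "", "print(add(1, 2))"]
fixed, last = replace_lines(lines, 2, 2, "    return a + b")
```

Because the slice end is exclusive, `end_line` needs no `- 1` adjustment, and a multi-line replacement can grow or shrink the file in a single assignment.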
inference.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
"""
|
| 2 |
-
Inference script for
|
| 3 |
|
| 4 |
Mandatory env vars expected in deployment config:
|
| 5 |
API_BASE_URL
|
|
@@ -24,9 +24,9 @@ from typing import Any
|
|
| 24 |
from openai import OpenAI
|
| 25 |
|
| 26 |
try:
|
| 27 |
-
from
|
| 28 |
except ImportError:
|
| 29 |
-
from client import
|
| 30 |
from models import CodeAction
|
| 31 |
|
| 32 |
|
|
@@ -36,8 +36,8 @@ HF_TOKEN = os.getenv("HF_TOKEN", "")
|
|
| 36 |
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME", "")
|
| 37 |
|
| 38 |
ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:7860")
|
| 39 |
-
TASK_NAME = os.getenv("TASK_NAME", "
|
| 40 |
-
BENCHMARK = os.getenv("BENCHMARK", "
|
| 41 |
MAX_STEPS = int(os.getenv("MAX_STEPS", "50"))
|
| 42 |
SUCCESS_SCORE_THRESHOLD = float(os.getenv("SUCCESS_SCORE_THRESHOLD", "0.99"))
|
| 43 |
|
|
@@ -170,7 +170,7 @@ def _compute_score(step_result: Any, rewards: list[float]) -> float:
|
|
| 170 |
async def main() -> None:
|
| 171 |
client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
|
| 172 |
|
| 173 |
-
env:
|
| 174 |
rewards: list[float] = []
|
| 175 |
history: list[str] = []
|
| 176 |
steps_taken = 0
|
|
@@ -180,9 +180,9 @@ async def main() -> None:
|
|
| 180 |
|
| 181 |
try:
|
| 182 |
if LOCAL_IMAGE_NAME:
|
| 183 |
-
env = await
|
| 184 |
else:
|
| 185 |
-
env =
|
| 186 |
|
| 187 |
result = await env.reset()
|
| 188 |
task_name = result.observation.info.get("task_name") or TASK_NAME
|
|
|
|
| 1 |
"""
|
| 2 |
+
Inference script for TraceFix-RL.
|
| 3 |
|
| 4 |
Mandatory env vars expected in deployment config:
|
| 5 |
API_BASE_URL
|
|
|
|
| 24 |
from openai import OpenAI
|
| 25 |
|
| 26 |
try:
|
| 27 |
+
from tracefix_rl import CodeAction, TraceFixRLEnv
|
| 28 |
except ImportError:
|
| 29 |
+
from client import TraceFixRLEnv
|
| 30 |
from models import CodeAction
|
| 31 |
|
| 32 |
|
|
|
|
| 36 |
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME", "")
|
| 37 |
|
| 38 |
ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:7860")
|
| 39 |
+
TASK_NAME = os.getenv("TASK_NAME", "tracefix_rl")
|
| 40 |
+
BENCHMARK = os.getenv("BENCHMARK", "tracefix_rl")
|
| 41 |
MAX_STEPS = int(os.getenv("MAX_STEPS", "50"))
|
| 42 |
SUCCESS_SCORE_THRESHOLD = float(os.getenv("SUCCESS_SCORE_THRESHOLD", "0.99"))
|
| 43 |
|
|
|
|
| 170 |
async def main() -> None:
|
| 171 |
client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
|
| 172 |
|
| 173 |
+
env: TraceFixRLEnv | None = None
|
| 174 |
rewards: list[float] = []
|
| 175 |
history: list[str] = []
|
| 176 |
steps_taken = 0
|
|
|
|
| 180 |
|
| 181 |
try:
|
| 182 |
if LOCAL_IMAGE_NAME:
|
| 183 |
+
env = await TraceFixRLEnv.from_docker_image(LOCAL_IMAGE_NAME)
|
| 184 |
else:
|
| 185 |
+
env = TraceFixRLEnv(base_url=ENV_BASE_URL)
|
| 186 |
|
| 187 |
result = await env.reset()
|
| 188 |
task_name = result.observation.info.get("task_name") or TASK_NAME
|
models.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
"""Pydantic schema layer for
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
|
|
|
| 1 |
+
"""Pydantic schema layer for TraceFix-RL."""
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
openenv.yaml
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
spec_version: 1
|
| 2 |
-
name:
|
| 3 |
type: space
|
| 4 |
runtime: fastapi
|
| 5 |
app: server.app:app
|
|
|
|
| 1 |
spec_version: 1
|
| 2 |
+
name: tracefix_rl
|
| 3 |
type: space
|
| 4 |
runtime: fastapi
|
| 5 |
app: server.app:app
|
pyproject.toml
CHANGED
|
@@ -9,9 +9,9 @@ requires = ["setuptools>=45", "wheel"]
|
|
| 9 |
build-backend = "setuptools.build_meta"
|
| 10 |
|
| 11 |
[project]
|
| 12 |
-
name = "openenv-
|
| 13 |
version = "0.1.0"
|
| 14 |
-
description = "
|
| 15 |
requires-python = ">=3.10"
|
| 16 |
dependencies = [
|
| 17 |
# Core OpenEnv runtime (provides FastAPI server + HTTP client types)
|
|
@@ -38,10 +38,10 @@ dev = [
|
|
| 38 |
|
| 39 |
[project.scripts]
|
| 40 |
# Server entry point - enables running via: uv run --project . server
|
| 41 |
-
# or: python -m
|
| 42 |
-
server = "
|
| 43 |
|
| 44 |
[tool.setuptools]
|
| 45 |
include-package-data = true
|
| 46 |
-
packages = ["
|
| 47 |
-
package-dir = { "
|
|
|
|
| 9 |
build-backend = "setuptools.build_meta"
|
| 10 |
|
| 11 |
[project]
|
| 12 |
+
name = "openenv-tracefix-rl"
|
| 13 |
version = "0.1.0"
|
| 14 |
+
description = "TraceFix-RL environment for OpenEnv"
|
| 15 |
requires-python = ">=3.10"
|
| 16 |
dependencies = [
|
| 17 |
# Core OpenEnv runtime (provides FastAPI server + HTTP client types)
|
|
|
|
| 38 |
|
| 39 |
[project.scripts]
|
| 40 |
# Server entry point - enables running via: uv run --project . server
|
| 41 |
+
# or: python -m tracefix_rl.server.app
|
| 42 |
+
server = "tracefix_rl.server.app:main"
|
| 43 |
|
| 44 |
[tool.setuptools]
|
| 45 |
include-package-data = true
|
| 46 |
+
packages = ["tracefix_rl", "tracefix_rl.server"]
|
| 47 |
+
package-dir = { "tracefix_rl" = ".", "tracefix_rl.server" = "server" }
|
sandbox.py
CHANGED
|
@@ -44,17 +44,11 @@ except ImportError:
|
|
| 44 |
from models import TestResult
|
| 45 |
|
| 46 |
|
| 47 |
-
# ---------------------------------------------------------------------------
|
| 48 |
-
# Constants
|
| 49 |
-
# ---------------------------------------------------------------------------
|
| 50 |
|
| 51 |
EXEC_TIMEOUT_SECONDS: int = 5 # Hard wall-clock kill limit (Principle 8)
|
| 52 |
MAX_OUTPUT_CHARS: int = 1_000 # Tail-truncate limit (Principle 9)
|
| 53 |
|
| 54 |
|
| 55 |
-
# ---------------------------------------------------------------------------
|
| 56 |
-
# Restricted builtins (Principle 8)
|
| 57 |
-
# ---------------------------------------------------------------------------
|
| 58 |
|
| 59 |
def _make_safe_stub(name: str) -> Callable:
|
| 60 |
"""Return a callable that raises RuntimeError — used to block dangerous builtins."""
|
|
@@ -68,26 +62,19 @@ def _make_safe_stub(name: str) -> Callable:
|
|
| 68 |
return _stub
|
| 69 |
|
| 70 |
|
| 71 |
-
# Whitelist: safe builtins the agent's code is allowed to use.
|
| 72 |
-
# Everything not in this dict is blocked.
|
| 73 |
_SAFE_BUILTINS: Dict[str, Any] = {
|
| 74 |
-
# Type constructors
|
| 75 |
"int": int, "float": float, "str": str, "bool": bool,
|
| 76 |
"list": list, "dict": dict, "set": set, "tuple": tuple,
|
| 77 |
"bytes": bytes, "bytearray": bytearray, "frozenset": frozenset,
|
| 78 |
"complex": complex,
|
| 79 |
-
# Inspection / iteration
|
| 80 |
"len": len, "range": range, "enumerate": enumerate, "zip": zip,
|
| 81 |
"map": map, "filter": filter, "reversed": reversed, "sorted": sorted,
|
| 82 |
"iter": iter, "next": next, "sum": sum, "min": min, "max": max,
|
| 83 |
"abs": abs, "round": round, "divmod": divmod, "pow": pow,
|
| 84 |
-
# Introspection
|
| 85 |
"isinstance": isinstance, "issubclass": issubclass, "type": type,
|
| 86 |
"hasattr": hasattr, "getattr": getattr, "setattr": setattr,
|
| 87 |
"callable": callable, "repr": repr, "hash": hash, "id": id,
|
| 88 |
-
# I/O (stdout only — stderr is captured separately)
|
| 89 |
"print": print,
|
| 90 |
-
# Exceptions & control
|
| 91 |
"Exception": Exception, "ValueError": ValueError, "TypeError": TypeError,
|
| 92 |
"KeyError": KeyError, "IndexError": IndexError, "AttributeError": AttributeError,
|
| 93 |
"StopIteration": StopIteration, "RuntimeError": RuntimeError,
|
|
@@ -96,13 +83,11 @@ _SAFE_BUILTINS: Dict[str, Any] = {
|
|
| 96 |
"RecursionError": RecursionError, "MemoryError": MemoryError,
|
| 97 |
"KeyboardInterrupt": KeyboardInterrupt,
|
| 98 |
"BaseException": BaseException,
|
| 99 |
-
# Functional
|
| 100 |
"any": any, "all": all,
|
| 101 |
"chr": chr, "ord": ord, "hex": hex, "oct": oct, "bin": bin,
|
| 102 |
"format": format,
|
| 103 |
"object": object, "property": property, "staticmethod": staticmethod,
|
| 104 |
"classmethod": classmethod, "super": super,
|
| 105 |
-
# Blocked with stubs (Principle 8)
|
| 106 |
"open": _make_safe_stub("open"),
|
| 107 |
"__import__": _make_safe_stub("__import__"),
|
| 108 |
"eval": _make_safe_stub("eval"),
|
|
@@ -119,9 +104,6 @@ _SAFE_BUILTINS: Dict[str, Any] = {
|
|
| 119 |
}
|
| 120 |
|
| 121 |
|
| 122 |
-
# ---------------------------------------------------------------------------
|
| 123 |
-
# Output truncation (Principle 9)
|
| 124 |
-
# ---------------------------------------------------------------------------
|
| 125 |
|
| 126 |
def _tail_truncate(s: str, limit: int = MAX_OUTPUT_CHARS) -> str:
|
| 127 |
"""
|
|
@@ -138,9 +120,6 @@ def _tail_truncate(s: str, limit: int = MAX_OUTPUT_CHARS) -> str:
|
|
| 138 |
return f"[...truncated {dropped} chars...]\n" + s[-limit:]
|
| 139 |
|
| 140 |
|
| 141 |
-
# ---------------------------------------------------------------------------
|
| 142 |
-
# Worker (runs in isolated child process)
|
| 143 |
-
# ---------------------------------------------------------------------------
|
| 144 |
|
| 145 |
def _worker(
|
| 146 |
source: str,
|
|
@@ -162,41 +141,29 @@ def _worker(
|
|
| 162 |
fn_name = "<unknown>"
|
| 163 |
|
| 164 |
try:
|
| 165 |
-
# ── Phase 1: Syntax check ─────────────────────────────────────────
|
| 166 |
-
# Compile before exec() so SyntaxError is caught cleanly.
|
| 167 |
try:
|
| 168 |
code_obj = compile(source, "<agent_code>", "exec")
|
| 169 |
except SyntaxError as exc:
|
| 170 |
had_syntax_error = True
|
| 171 |
-
# Restore streams before writing the error
|
| 172 |
sys.stdout, sys.stderr = old_stdout, old_stderr
|
| 173 |
err = f"SyntaxError at line {exc.lineno}: {exc.msg}\n >> {exc.text or ''}"
|
| 174 |
result_queue.put((_tail_truncate(err), [], True))
|
| 175 |
return
|
| 176 |
|
| 177 |
-
# ── Phase 2: Execute agent code into a sandboxed namespace ───────
|
| 178 |
-
# Use full __builtins__ to prevent __build_class__ errors for class-based tasks.
|
| 179 |
namespace: Dict[str, Any] = {"__builtins__": __builtins__}
|
| 180 |
try:
|
| 181 |
exec(code_obj, namespace) # noqa: S102
|
| 182 |
except Exception: # noqa: BLE001
|
| 183 |
-
# PRINCIPLE 2: execution crash is data, not a crash
|
| 184 |
tb = traceback.format_exc()
|
| 185 |
sys.stdout, sys.stderr = old_stdout, old_stderr
|
| 186 |
result_queue.put((_tail_truncate(buf.getvalue() + "\n" + tb), [], False))
|
| 187 |
return
|
| 188 |
|
| 189 |
-
# ── Phase 3: Run each test function ──────────────────────────────
|
| 190 |
-
# PRINCIPLE 2: each test is isolated inside its own try-except so a
|
| 191 |
-
# crash in test N does not prevent tests N+1..M from running.
|
| 192 |
for test_src in test_sources:
|
| 193 |
fn_name = "<unknown>"
|
| 194 |
try:
|
| 195 |
-
# Inject the test function into the existing namespace so it
|
| 196 |
-
# can access the agent's defined symbols.
|
| 197 |
exec(test_src, namespace) # noqa: S102
|
| 198 |
|
| 199 |
-
# Extract the last `def` name from the test source.
|
| 200 |
fn_name = [
|
| 201 |
ln.split("(")[0].replace("def ", "").strip()
|
| 202 |
for ln in test_src.splitlines()
|
|
@@ -207,7 +174,6 @@ def _worker(
|
|
| 207 |
test_results.append({"test_name": fn_name, "passed": True})
|
| 208 |
|
| 209 |
except AssertionError as exc:
|
| 210 |
-
# PRINCIPLE 2: assertion failure is structured data
|
| 211 |
test_results.append({
|
| 212 |
"test_name": fn_name,
|
| 213 |
"passed": False,
|
|
@@ -216,7 +182,6 @@ def _worker(
|
|
| 216 |
),
|
| 217 |
})
|
| 218 |
except Exception: # noqa: BLE001
|
| 219 |
-
# PRINCIPLE 2: all other exceptions also become structured data
|
| 220 |
test_results.append({
|
| 221 |
"test_name": fn_name,
|
| 222 |
"passed": False,
|
|
@@ -224,7 +189,6 @@ def _worker(
|
|
| 224 |
})
|
| 225 |
|
| 226 |
except Exception: # noqa: BLE001
|
| 227 |
-
# Catch-all for any unexpected failure in the harness itself
|
| 228 |
traceback.print_exc(file=buf)
|
| 229 |
finally:
|
| 230 |
sys.stdout, sys.stderr = old_stdout, old_stderr
|
|
@@ -233,9 +197,6 @@ def _worker(
|
|
| 233 |
result_queue.put((captured, test_results, had_syntax_error))
|
| 234 |
|
| 235 |
|
| 236 |
-
# ---------------------------------------------------------------------------
|
| 237 |
-
# Public API
|
| 238 |
-
# ---------------------------------------------------------------------------
|
| 239 |
|
| 240 |
def check_syntax(source: str) -> Tuple[bool, str]:
|
| 241 |
"""
|
|
@@ -271,7 +232,6 @@ def run_code_with_tests(
|
|
| 271 |
-------
|
| 272 |
(output_str, test_results, had_syntax_error)
|
| 273 |
"""
|
| 274 |
-
# Serialise callables → source strings (required for pickling across processes)
|
| 275 |
test_sources = [
|
| 276 |
textwrap.dedent(inspect.getsource(fn))
|
| 277 |
for fn in test_callables
|
|
@@ -286,7 +246,6 @@ def run_code_with_tests(
|
|
| 286 |
proc.start()
|
| 287 |
proc.join(timeout)
|
| 288 |
|
| 289 |
-
# PRINCIPLE 8 — hard kill (SIGTERM first, SIGKILL if still alive)
|
| 290 |
if proc.is_alive():
|
| 291 |
proc.terminate()
|
| 292 |
proc.join(2) # Give it 2s to handle SIGTERM gracefully
|
|
|
|
| 44 |
from models import TestResult
|
| 45 |
|
| 46 |
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
EXEC_TIMEOUT_SECONDS: int = 5 # Hard wall-clock kill limit (Principle 8)
|
| 49 |
MAX_OUTPUT_CHARS: int = 1_000 # Tail-truncate limit (Principle 9)
|
| 50 |
|
| 51 |
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
def _make_safe_stub(name: str) -> Callable:
|
| 54 |
"""Return a callable that raises RuntimeError — used to block dangerous builtins."""
|
|
|
|
| 62 |
return _stub
|
| 63 |
|
| 64 |
|
|
|
|
|
|
|
| 65 |
_SAFE_BUILTINS: Dict[str, Any] = {
|
|
|
|
| 66 |
"int": int, "float": float, "str": str, "bool": bool,
|
| 67 |
"list": list, "dict": dict, "set": set, "tuple": tuple,
|
| 68 |
"bytes": bytes, "bytearray": bytearray, "frozenset": frozenset,
|
| 69 |
"complex": complex,
|
|
|
|
| 70 |
"len": len, "range": range, "enumerate": enumerate, "zip": zip,
|
| 71 |
"map": map, "filter": filter, "reversed": reversed, "sorted": sorted,
|
| 72 |
"iter": iter, "next": next, "sum": sum, "min": min, "max": max,
|
| 73 |
"abs": abs, "round": round, "divmod": divmod, "pow": pow,
|
|
|
|
| 74 |
"isinstance": isinstance, "issubclass": issubclass, "type": type,
|
| 75 |
"hasattr": hasattr, "getattr": getattr, "setattr": setattr,
|
| 76 |
"callable": callable, "repr": repr, "hash": hash, "id": id,
|
|
|
|
| 77 |
"print": print,
|
|
|
|
| 78 |
"Exception": Exception, "ValueError": ValueError, "TypeError": TypeError,
|
| 79 |
"KeyError": KeyError, "IndexError": IndexError, "AttributeError": AttributeError,
|
| 80 |
"StopIteration": StopIteration, "RuntimeError": RuntimeError,
|
|
|
|
| 83 |
"RecursionError": RecursionError, "MemoryError": MemoryError,
|
| 84 |
"KeyboardInterrupt": KeyboardInterrupt,
|
| 85 |
"BaseException": BaseException,
|
|
|
|
| 86 |
"any": any, "all": all,
|
| 87 |
"chr": chr, "ord": ord, "hex": hex, "oct": oct, "bin": bin,
|
| 88 |
"format": format,
|
| 89 |
"object": object, "property": property, "staticmethod": staticmethod,
|
| 90 |
"classmethod": classmethod, "super": super,
|
|
|
|
| 91 |
"open": _make_safe_stub("open"),
|
| 92 |
"__import__": _make_safe_stub("__import__"),
|
| 93 |
"eval": _make_safe_stub("eval"),
|
|
|
|
| 104 |
}
|
| 105 |
|
| 106 |
|
|
|
|
|
|
|
|
|
|
| 107 |
|
| 108 |
def _tail_truncate(s: str, limit: int = MAX_OUTPUT_CHARS) -> str:
|
| 109 |
"""
|
|
|
|
| 120 |
return f"[...truncated {dropped} chars...]\n" + s[-limit:]
|
| 121 |
|
| 122 |
|
|
|
|
|
|
|
|
|
|
| 123 |
|
| 124 |
def _worker(
|
| 125 |
source: str,
|
|
|
|
| 141 |
fn_name = "<unknown>"
|
| 142 |
|
| 143 |
try:
|
|
|
|
|
|
|
| 144 |
try:
|
| 145 |
code_obj = compile(source, "<agent_code>", "exec")
|
| 146 |
except SyntaxError as exc:
|
| 147 |
had_syntax_error = True
|
|
|
|
| 148 |
sys.stdout, sys.stderr = old_stdout, old_stderr
|
| 149 |
err = f"SyntaxError at line {exc.lineno}: {exc.msg}\n >> {exc.text or ''}"
|
| 150 |
result_queue.put((_tail_truncate(err), [], True))
|
| 151 |
return
|
| 152 |
|
|
|
|
|
|
|
| 153 |
namespace: Dict[str, Any] = {"__builtins__": __builtins__}
|
| 154 |
try:
|
| 155 |
exec(code_obj, namespace) # noqa: S102
|
| 156 |
except Exception: # noqa: BLE001
|
|
|
|
| 157 |
tb = traceback.format_exc()
|
| 158 |
sys.stdout, sys.stderr = old_stdout, old_stderr
|
| 159 |
result_queue.put((_tail_truncate(buf.getvalue() + "\n" + tb), [], False))
|
| 160 |
return
|
| 161 |
|
|
|
|
|
|
|
|
|
|
| 162 |
for test_src in test_sources:
|
| 163 |
fn_name = "<unknown>"
|
| 164 |
try:
|
|
|
|
|
|
|
| 165 |
exec(test_src, namespace) # noqa: S102
|
| 166 |
|
|
|
|
| 167 |
fn_name = [
|
| 168 |
ln.split("(")[0].replace("def ", "").strip()
|
| 169 |
for ln in test_src.splitlines()
|
|
|
|
| 174 |
test_results.append({"test_name": fn_name, "passed": True})
|
| 175 |
|
| 176 |
except AssertionError as exc:
|
|
|
|
| 177 |
test_results.append({
|
| 178 |
"test_name": fn_name,
|
| 179 |
"passed": False,
|
|
|
|
| 182 |
),
|
| 183 |
})
|
| 184 |
except Exception: # noqa: BLE001
|
|
|
|
| 185 |
test_results.append({
|
| 186 |
"test_name": fn_name,
|
| 187 |
"passed": False,
|
|
|
|
| 189 |
})
|
| 190 |
|
| 191 |
except Exception: # noqa: BLE001
|
|
|
|
| 192 |
traceback.print_exc(file=buf)
|
| 193 |
finally:
|
| 194 |
sys.stdout, sys.stderr = old_stdout, old_stderr
|
|
|
|
| 197 |
result_queue.put((captured, test_results, had_syntax_error))
|
| 198 |
|
| 199 |
|
|
|
|
|
|
|
|
|
|
| 200 |
|
| 201 |
def check_syntax(source: str) -> Tuple[bool, str]:
|
| 202 |
"""
|
|
|
|
| 232 |
-------
|
| 233 |
(output_str, test_results, had_syntax_error)
|
| 234 |
"""
|
|
|
|
| 235 |
test_sources = [
|
| 236 |
textwrap.dedent(inspect.getsource(fn))
|
| 237 |
for fn in test_callables
|
|
|
|
| 246 |
proc.start()
|
| 247 |
proc.join(timeout)
|
| 248 |
|
|
|
|
| 249 |
if proc.is_alive():
|
| 250 |
proc.terminate()
|
| 251 |
proc.join(2) # Give it 2s to handle SIGTERM gracefully
|
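The `sandbox.py` diff shows only the final return line of `_tail_truncate`, which keeps the end of the captured output and prefixes a note about how much was dropped. A minimal sketch of how such a function could look, assuming a short-string early return and a `dropped` computation that the diff does not show (both are assumptions, as is the public name `tail_truncate`):

```python
MAX_OUTPUT_CHARS = 1_000  # same value as the constant in the diff


def tail_truncate(s: str, limit: int = MAX_OUTPUT_CHARS) -> str:
    """Keep only the last `limit` characters and note how many were dropped."""
    if len(s) <= limit:
        return s  # assumed: short output passes through unchanged
    dropped = len(s) - limit  # assumed: count of characters cut from the front
    return f"[...truncated {dropped} chars...]\n" + s[-limit:]


# Long noisy output followed by the part an agent actually needs to see.
sample = "x" * 1_200 + "AssertionError: expected 3, got -1"
short = tail_truncate(sample)
```

Truncating from the front rather than the back keeps the most recent traceback lines, which usually carry the failing assertion.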
server/__init__.py
CHANGED
|
@@ -1,11 +1,5 @@
|
|
| 1 |
-
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the BSD-style license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
|
| 7 |
-
|
| 8 |
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
__all__ = ["SWEGymEnvironment"]
|
|
|
|
| 1 |
+
"""TraceFix-RL server components."""
|
| 2 |
|
| 3 |
+
from .tracefix_rl_environment import TraceFixRLEnvironment
|
| 4 |
|
| 5 |
+
__all__ = ["TraceFixRLEnvironment"]
|
|
|
|
|
|
server/app.py
CHANGED
|
@@ -1,10 +1,4 @@
|
|
| 1 |
-
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the BSD-style license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
"""FastAPI entry point for SWE-Gym - Software Engineer Gym."""
|
| 8 |
|
| 9 |
try:
|
| 10 |
from openenv.core.env_server.http_server import create_app
|
|
@@ -15,23 +9,22 @@ except Exception as e: # pragma: no cover
|
|
| 15 |
|
| 16 |
try:
|
| 17 |
from ..models import CodeAction, CodeObservation
|
| 18 |
-
from .
|
| 19 |
except ImportError:
|
| 20 |
import sys
|
| 21 |
from pathlib import Path
|
| 22 |
|
| 23 |
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
|
| 24 |
from models import CodeAction, CodeObservation
|
| 25 |
-
from server.
|
| 26 |
|
| 27 |
|
| 28 |
-
# Create the app with web interface and README integration
|
| 29 |
app = create_app(
|
| 30 |
-
|
| 31 |
CodeAction,
|
| 32 |
CodeObservation,
|
| 33 |
-
env_name="
|
| 34 |
-
max_concurrent_envs=1,
|
| 35 |
)
|
| 36 |
|
| 37 |
|
|
|
|
| 1 |
+
"""FastAPI entry point for TraceFix-RL."""
|
| 2 |
|
| 3 |
try:
|
| 4 |
from openenv.core.env_server.http_server import create_app
|
|
|
|
| 9 |
|
| 10 |
try:
|
| 11 |
from ..models import CodeAction, CodeObservation
|
| 12 |
+
from .tracefix_rl_environment import TraceFixRLEnvironment
|
| 13 |
except ImportError:
|
| 14 |
import sys
|
| 15 |
from pathlib import Path
|
| 16 |
|
| 17 |
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
|
| 18 |
from models import CodeAction, CodeObservation
|
| 19 |
+
from server.tracefix_rl_environment import TraceFixRLEnvironment
|
| 20 |
|
| 21 |
|
|
|
|
| 22 |
app = create_app(
|
| 23 |
+
TraceFixRLEnvironment,
|
| 24 |
CodeAction,
|
| 25 |
CodeObservation,
|
| 26 |
+
env_name="tracefix_rl",
|
| 27 |
+
max_concurrent_envs=1,
|
| 28 |
)
|
| 29 |
|
| 30 |
|
server/{swe_gym_environment.py → tracefix_rl_environment.py}
RENAMED
|
@@ -1,33 +1,23 @@
|
|
| 1 |
-
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the BSD-style license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
"""OpenEnv adapter around the SWE-Gym core environment."""
|
| 8 |
|
| 9 |
from openenv.core.env_server.interfaces import Environment
|
| 10 |
from openenv.core.env_server.types import State
|
| 11 |
|
| 12 |
try:
|
| 13 |
-
from ..environment import
|
| 14 |
from ..models import CodeAction, CodeObservation
|
| 15 |
except ImportError:
|
| 16 |
-
from environment import
|
| 17 |
from models import CodeAction, CodeObservation
|
| 18 |
|
| 19 |
|
| 20 |
-
class
|
| 21 |
"""Environment implementation compatible with OpenEnv's server interface."""
|
| 22 |
|
| 23 |
-
# Enable concurrent WebSocket sessions.
|
| 24 |
-
# Set to True if your environment isolates state between instances.
|
| 25 |
-
# When True, multiple WebSocket clients can connect simultaneously, each
|
| 26 |
-
# getting their own environment instance (when using factory mode in app.py).
|
| 27 |
SUPPORTS_CONCURRENT_SESSIONS: bool = True
|
| 28 |
|
| 29 |
def __init__(self):
|
| 30 |
-
self._gym =
|
| 31 |
self._state = State(episode_id="", step_count=0)
|
| 32 |
|
| 33 |
def reset(self) -> CodeObservation:
|
|
@@ -56,10 +46,4 @@ class SWEGymEnvironment(Environment):
|
|
| 56 |
|
| 57 |
@property
|
| 58 |
def state(self) -> State:
|
| 59 |
-
"""
|
| 60 |
-
Get the current environment state.
|
| 61 |
-
|
| 62 |
-
Returns:
|
| 63 |
-
Current State with episode_id and step_count
|
| 64 |
-
"""
|
| 65 |
return self._state
|
|
|
|
| 1 |
+
"""OpenEnv adapter around the TraceFix-RL core environment."""
|
| 2 |
|
| 3 |
from openenv.core.env_server.interfaces import Environment
|
| 4 |
from openenv.core.env_server.types import State
|
| 5 |
|
| 6 |
try:
|
| 7 |
+
from ..environment import TraceFixRLGym
|
| 8 |
from ..models import CodeAction, CodeObservation
|
| 9 |
except ImportError:
|
| 10 |
+
from environment import TraceFixRLGym
|
| 11 |
from models import CodeAction, CodeObservation
|
| 12 |
|
| 13 |
|
| 14 |
+
class TraceFixRLEnvironment(Environment):
|
| 15 |
"""Environment implementation compatible with OpenEnv's server interface."""
|
| 16 |
|
| 17 |
SUPPORTS_CONCURRENT_SESSIONS: bool = True
|
| 18 |
|
| 19 |
def __init__(self):
|
| 20 |
+
self._gym = TraceFixRLGym()
|
| 21 |
self._state = State(episode_id="", step_count=0)
|
| 22 |
|
| 23 |
def reset(self) -> CodeObservation:
|
|
|
|
| 46 |
|
| 47 |
@property
|
| 48 |
def state(self) -> State:
|
| 49 |
return self._state
|