Viraj0112 commited on
Commit
03a907a
·
verified ·
1 Parent(s): 2259499

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. Dockerfile +78 -0
  2. README.md +182 -215
  3. __init__.py +14 -0
  4. _aliases.py +21 -0
  5. client.py +185 -0
  6. conftest.py +21 -0
  7. dataset/README.md +20 -0
  8. dataset/__init__.py +1 -0
  9. dataset/generate_swebench_tasks.py +498 -0
  10. dataset/loader.py +111 -0
  11. dataset/prepare_swebench.py +274 -0
  12. dataset/problem_1/buggy.py +7 -0
  13. dataset/problem_1/metadata.json +5 -0
  14. dataset/problem_1/test.py +18 -0
  15. dataset/problem_10/buggy.py +8 -0
  16. dataset/problem_10/helpers.py +2 -0
  17. dataset/problem_10/metadata.json +5 -0
  18. dataset/problem_10/test.py +12 -0
  19. dataset/problem_11/buggy.py +14 -0
  20. dataset/problem_11/metadata.json +5 -0
  21. dataset/problem_11/test.py +17 -0
  22. dataset/problem_12/buggy.py +11 -0
  23. dataset/problem_12/metadata.json +5 -0
  24. dataset/problem_12/test.py +14 -0
  25. dataset/problem_13/buggy.py +10 -0
  26. dataset/problem_13/cache.py +20 -0
  27. dataset/problem_13/metadata.json +5 -0
  28. dataset/problem_13/test.py +13 -0
  29. dataset/problem_14/buggy.py +6 -0
  30. dataset/problem_14/metadata.json +5 -0
  31. dataset/problem_14/test.py +15 -0
  32. dataset/problem_15/buggy.py +4 -0
  33. dataset/problem_15/metadata.json +5 -0
  34. dataset/problem_15/test.py +14 -0
  35. dataset/problem_16/buggy.py +10 -0
  36. dataset/problem_16/helpers.py +3 -0
  37. dataset/problem_16/metadata.json +5 -0
  38. dataset/problem_16/test.py +12 -0
  39. dataset/problem_17/buggy.py +11 -0
  40. dataset/problem_17/metadata.json +5 -0
  41. dataset/problem_17/test.py +11 -0
  42. dataset/problem_18/buggy.py +14 -0
  43. dataset/problem_18/math_utils.py +6 -0
  44. dataset/problem_18/metadata.json +5 -0
  45. dataset/problem_18/test.py +14 -0
  46. dataset/problem_19/buggy.py +36 -0
  47. dataset/problem_19/metadata.json +5 -0
  48. dataset/problem_19/test.py +48 -0
  49. dataset/problem_2/buggy.py +14 -0
  50. dataset/problem_2/metadata.json +5 -0
Dockerfile ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
# Multi-stage build: resolve dependencies with uv in a builder stage,
# then copy the finished /app/env tree (including its .venv) into a
# slim runtime image based on the same base.
ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
FROM ${BASE_IMAGE} AS builder

WORKDIR /app

# git and curl are needed to fetch uv and any VCS dependencies.
RUN apt-get update && \
    apt-get install -y --no-install-recommends git curl ca-certificates && \
    rm -rf /var/lib/apt/lists/*

# Build argument to control whether we're building standalone or in-repo.
ARG BUILD_MODE=in-repo
ARG ENV_NAME=rl_code_fix_env

# Copy environment code (always at root of build context).
COPY . /app/env

# For in-repo builds, openenv is already vendored in the build context.
# For standalone builds, openenv will be installed via pyproject.toml.
WORKDIR /app/env

# Install uv only if the base image does not already provide it.
RUN if ! command -v uv >/dev/null 2>&1; then \
        curl -LsSf https://astral.sh/uv/install.sh | env UV_INSTALL_DIR=/usr/local/bin sh; \
    fi

# Resolve and install dependencies with uv sync, using the lockfile when
# present. First pass skips the project itself so the dependency layer
# caches independently of source changes.
RUN --mount=type=cache,target=/root/.cache/uv \
    if [ -f uv.lock ]; then \
        uv sync --frozen --no-install-project --no-editable; \
    else \
        uv sync --no-install-project --no-editable; \
    fi

# Second pass installs the project itself on top of the cached deps.
RUN --mount=type=cache,target=/root/.cache/uv \
    if [ -f uv.lock ]; then \
        uv sync --frozen --no-editable; \
    else \
        uv sync --no-editable; \
    fi

# ---------------------------------------------------------------------------
# Final runtime stage
# ---------------------------------------------------------------------------
FROM ${BASE_IMAGE}

# curl is required by the HEALTHCHECK below.
RUN apt-get update && \
    apt-get install -y --no-install-recommends curl && \
    rm -rf /var/lib/apt/lists/*

WORKDIR /app

# Copy environment code + its in-place virtualenv from the builder.
# Keep the venv at the same path it was created with (/app/env/.venv)
# to avoid relocation issues and dual-venv path conflicts.
COPY --from=builder /app/env /app/env

# Activate the single in-repo venv for all subsequent commands.
ENV VIRTUAL_ENV="/app/env/.venv"
ENV PATH="/app/env/.venv/bin:$PATH"

# Hermetic runtime: keep imports pinned to repo code + active venv.
ENV PYTHONPATH="/app/env"
ENV PYTHONNOUSERSITE="1"
ENV PYTHONDONTWRITEBYTECODE="1"

# Container health is the FastAPI /health endpoint.
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
    CMD curl -f http://localhost:8000/health || exit 1

# Expose the application port.
EXPOSE 8000

# Run the FastAPI server from /app/env so the server.app module resolves.
ENV ENABLE_WEB_INTERFACE=true
CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
README.md CHANGED
@@ -1,288 +1,255 @@
1
- # TraceRL Mini Environment for Autonomous Code Fixing
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- This repository packages an OpenEnv-compatible reinforcement learning environment for autonomous Python bug fixing. An agent receives buggy code, can apply unified-diff patches, run the task's tests, inspect logs, and is rewarded for functional progress, reasonable debugging traces, and solving the problem within a step budget.
4
 
5
- ## Environment Overview and Motivation
6
 
7
- The core environment lives in `rl_code_fix_env/` and wraps a code-repair loop around three pieces of functionality:
8
 
9
- 1. Load a bug-fixing task from either a local curated dataset or a materialized SWE-bench Lite workspace.
10
- 2. Let the agent iteratively edit the current `buggy.py` contents with `apply_patch`, then execute the task test file.
11
- 3. Return observations and rewards that make the environment suitable for RL-style training and evaluation.
12
 
13
- The motivation is to benchmark whether an autonomous agent can do more than generate one-shot code. It must:
14
-
15
- - read failing code,
16
- - produce minimal patches,
17
- - use test feedback to refine its fix,
18
- - manage a limited interaction budget,
19
- - and recover from bad intermediate edits.
20
-
21
- This repo also includes a baseline `inference.py` script, containerization for OpenEnv/Hugging Face Spaces deployment, and run logs for a reference baseline.
22
 
23
- ## Repository Layout
 
 
24
 
25
- - `rl_code_fix_env/`: main OpenEnv package.
26
- - `rl_code_fix_env/src/environment/environment.py`: core RL environment logic.
27
- - `rl_code_fix_env/src/reward/`: reward shaping and trace scoring.
28
- - `rl_code_fix_env/src/sandbox/`: unified-diff patching and test execution sandbox.
29
- - `rl_code_fix_env/dataset/`: local bug-fixing tasks and metadata.
30
- - `rl_code_fix_env/server/`: FastAPI/OpenEnv server and Dockerfile.
31
- - `rl_code_fix_env/inference.py`: baseline inference agent.
32
- - `logs.md`: recorded baseline run output.
33
 
34
- ## Action Space
 
35
 
36
- The action model is defined in `rl_code_fix_env/models.py` as:
 
 
 
 
 
37
 
38
- ```python
39
- CodeFixerAction(
40
- type: str,
41
- payload: Optional[str] = None,
42
- )
43
  ```
44
 
45
- Supported action types:
46
-
47
- - `apply_patch`: `payload` is a unified diff patch. The environment fuzzily applies hunks to the current code string.
48
- - `run_tests`: executes the task's `test.py` and updates pass/fail state and logs.
49
- - `get_logs`: returns the most recent logs without changing code.
50
-
51
- Practical meaning:
52
 
53
- - `apply_patch` is the editing action.
54
- - `run_tests` is the feedback action.
55
- - `get_logs` is a cheap inspection action when the agent wants the last failure output again.
56
 
57
- ## Observation Space
58
 
59
- The observation model is also defined in `rl_code_fix_env/models.py`:
60
-
61
- ```python
62
- CodeFixerObservation(
63
- code: str = "",
64
- logs: Optional[str] = None,
65
- test_score: float = 0.0,
66
- total_tests: int = 1,
67
- steps: int = 0,
68
- done: bool = False,
69
- reward: Optional[float] = None,
70
- )
71
  ```
72
 
73
- Field meanings:
74
-
75
- - `code`: the current patched source code under repair.
76
- - `logs`: latest pytest output or startup/fallback messages.
77
- - `test_score`: normalized functional score. In the current local tasks it is `1.0` for pass and `0.0` for fail.
78
- - `total_tests`: number of task test files tracked by the environment. Current local tasks use a single target test file.
79
- - `steps`: number of patch actions consumed so far.
80
- - `done`: episode termination flag.
81
- - `reward`: latest reward returned by the environment wrapper.
82
 
83
- ## Reward Design
84
 
85
- The reward is computed in `rl_code_fix_env/src/reward/reward.py`:
 
 
86
 
87
- ```text
88
- reward =
89
- 0.7 * functional_reward
90
- + 0.2 * trace_reward
91
- + 0.1 * quality_reward
92
- - efficiency_penalty
93
  ```
94
 
95
- Where:
 
 
 
96
 
97
- - `functional_reward = test_score`
98
- - `trace_reward = score_trace(trace_obj)`
99
- - `quality_reward = 1.0` when non-empty code exists, else `0.0`
100
- - `efficiency_penalty = 0.05 * (steps_taken / max_steps)`
101
 
102
- If all tests pass, the environment overrides the reward to `1.0`.
103
 
104
- ## Task Descriptions and Expected Difficulty Levels
105
 
106
- ### Official competition-facing task mapping
 
 
 
107
 
108
- The current local fallback dataset exposes one canonical task per difficulty through `get_hardcoded_task(...)`:
109
 
110
- | Difficulty | Problem ID | Description | Bug type | Expected steps |
111
- | --- | --- | --- | --- | --- |
112
- | Easy | `problem_1` | Reverse words while normalizing repeated spaces | `string-splitting` | 1 |
113
- | Medium | `problem_10` | Rotate a matrix 90 degrees clockwise | `matrix-transformation` | 1 |
114
- | Hard | `problem_13` | Preserve recency correctly in an LRU cache | `state-logic` | 2 |
115
 
116
- Canonical task details:
 
117
 
118
- - `easy`:
119
- The buggy code uses `text.split(" ")`, which preserves empty tokens for repeated spaces. The fix is a small normalization change.
120
- - `medium`:
121
- The code transposes the matrix and then reverses rows in the wrong direction, producing a counter-clockwise rotation.
122
- - `hard`:
123
- The visible task calls into `cache.py`, where `LRUCache.get()` fails to refresh recency. This is stateful and effectively multi-file reasoning.
124
 
125
- ### Full local dataset coverage
 
126
 
127
- The local dataset currently contains 23 problems:
 
 
128
 
129
- - `easy`: 8 tasks
130
- - `medium`: 9 tasks
131
- - `hard`: 6 tasks
132
 
133
- Bug patterns represented across the dataset include:
 
 
 
 
134
 
135
- - whitespace and string normalization
136
- - off-by-one and boundary-condition mistakes
137
- - incorrect matrix and sorting transformations
138
- - recursion and exception-handling bugs
139
- - stateful cache logic and multi-bug hard tasks
140
 
141
- ### Difficulty interpretation
 
 
142
 
143
- - `easy`: usually a single-line or single-concept bug with direct test feedback.
144
- - `medium`: often requires understanding data transformation logic or helper-module behavior.
145
- - `hard`: commonly involves state, multi-step reasoning, or fixes that span more than one conceptual location.
 
 
 
 
146
 
147
- ## Episode Flow
 
 
 
 
148
 
149
- 1. `reset()` selects a difficulty.
150
- 2. The environment loads the buggy code, test path, workspace path, and zeroed metrics.
151
- 3. The agent alternates between `apply_patch`, `run_tests`, and optional `get_logs`.
152
- 4. The episode ends when all tests pass or the step budget is exhausted.
153
 
154
- By default, the server cycles through `easy`, `medium`, and `hard` on reset. You can force a specific difficulty with `TRACERL_TASK=easy`, `TRACERL_TASK=medium`, or `TRACERL_TASK=hard`.
155
 
156
- ## Data Sources
157
 
158
- `CodeEnv` defaults to `TASK_SOURCE=swebench`. If SWE-bench Lite task materialization is unavailable, it falls back to the local curated dataset when `SWEBENCH_FALLBACK_LOCAL=1` is enabled, which is the current default behavior.
 
159
 
160
- Expected SWE-bench Lite workspace layout:
 
161
 
162
- ```text
163
- rl_code_fix_env/dataset/swebench_lite_tasks/<instance_id>/
164
- buggy.py
165
- test.py
166
  ```
167
 
168
- ## Setup Instructions
169
 
170
- ### Local Python setup
171
 
172
- From the repository root:
173
 
174
- ```bash
175
- cd rl_code_fix_env
176
- uv sync
177
- ```
178
-
179
- If you are not using `uv`, install the shared dependencies from the repository root:
180
-
181
- ```bash
182
- pip install -r requirements.txt
 
 
183
  ```
184
 
185
- ### Required environment variables for inference
 
 
 
186
 
187
- The baseline agent expects:
188
 
189
- ```bash
190
- API_BASE_URL=<openai-compatible-endpoint>
191
- MODEL_NAME=<model-id>
192
- HF_TOKEN=<api-key>
193
- ```
194
 
195
- Useful optional variables:
196
-
197
- ```bash
198
- ENV_URL=http://localhost:8000
199
- TRACERL_TASK=easy
200
- TASK_SOURCE=swebench
201
- SWEBENCH_FALLBACK_LOCAL=1
202
- MAX_STEPS=10
203
- TEMPERATURE=0.2
204
- MAX_TOKENS=2048
205
- SUCCESS_THRESHOLD=1.0
206
- MAX_RETRIES=3
207
  ```
208
 
209
- ## Usage Instructions
210
 
211
- ### Run the environment server locally
212
-
213
- ```bash
214
- cd rl_code_fix_env
215
- uvicorn server.app:app --reload --host 0.0.0.0 --port 8000
 
 
 
 
 
 
 
 
 
216
  ```
217
 
218
- Alternative entry point:
219
 
220
- ```bash
221
- cd rl_code_fix_env
222
- uv run --project . server
223
- ```
224
-
225
- ### Run the baseline inference agent
226
 
227
- Open a second terminal:
228
 
229
  ```bash
230
- cd rl_code_fix_env
231
- python inference.py
232
  ```
233
 
234
- The script emits machine-parseable lines in this format:
 
 
 
 
235
 
236
- ```text
237
- [START] task=<task_name> env=<benchmark> model=<model_name>
238
- [STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
239
- [END] success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...,rn>
240
- ```
241
-
242
- ### Build and run with Docker
243
 
244
- From `rl_code_fix_env/`:
245
 
246
  ```bash
247
- docker build -t rl_code_fix_env-env:latest -f server/Dockerfile .
248
- docker run -p 8000:8000 rl_code_fix_env-env:latest
249
  ```
250
 
251
- ### OpenEnv / Hugging Face Spaces deployment
252
 
253
- From `rl_code_fix_env/`:
254
-
255
- ```bash
256
- openenv push
257
  ```
258
-
259
- The package is configured as a FastAPI OpenEnv space via `openenv.yaml`.
260
-
261
- ## Baseline Performance Scores
262
-
263
- The current recorded baseline in `logs.md` ran one episode each for `easy`, `medium`, and `hard` using model `qwen/qwen3-coder-480b-a35b-instruct`.
264
-
265
- | Task | Success | Steps | Final score | Reward trace | Cumulative reward |
266
- | --- | --- | --- | --- | --- | --- |
267
- | Easy | `false` | 10 | 0.00 | `0.14,0.13,0.12,0.11,0.10,0.09,0.08,0.07,0.06,0.05` | 0.95 |
268
- | Medium | `false` | 10 | 0.00 | `0.14,0.13,0.12,0.11,0.10,0.09,0.08,0.07,0.06,0.05` | 0.95 |
269
- | Hard | `false` | 10 | 0.00 | `0.14,0.13,0.12,0.11,0.10,0.09,0.08,0.07,0.06,0.05` | 0.95 |
270
-
271
- Aggregate baseline summary:
272
-
273
- - episodes evaluated: 3
274
- - success rate: `0/3`
275
- - mean final score: `0.00`
276
- - mean cumulative reward: `0.95`
277
-
278
- Interpretation:
279
-
280
- - The baseline agent produced syntactically plausible patches and collected small shaped rewards.
281
- - It did not achieve a passing test score on any recorded task.
282
- - The current baseline should be treated as a starting point rather than a competitive upper bound.
283
-
284
- ## Notes and Caveats
285
-
286
- - The local fallback tasks currently use one target test file per problem, so `test_score` is binary.
287
- - Patch application uses `unidiff` plus fuzzy matching from `diff-match-patch`, which makes the environment more tolerant to slightly stale context.
288
- - Test execution prefers Docker sandboxing, but falls back to direct `pytest` execution when Docker is unavailable.
 
1
+ ---
2
+ title: Rl Code Fix Env Environment Server
3
+ emoji: "🚀"
4
+ colorFrom: green
5
+ colorTo: purple
6
+ sdk: docker
7
+ pinned: false
8
+ app_port: 8000
9
+ base_path: /web
10
+ tags:
11
+ - openenv
12
+ ---
13
 
14
+ # Rl Code Fix Env Environment
15
 
16
+ A simple test environment that echoes back messages. Perfect for testing the env APIs as well as demonstrating environment usage patterns.
17
 
18
+ ## Quick Start
19
 
20
+ The simplest way to use the Rl Code Fix Env environment is through the `RlCodeFixEnv` class:
 
 
21
 
22
+ ```python
23
+ from rl_code_fix_env import RlCodeFixAction, RlCodeFixEnv
 
 
 
 
 
 
 
24
 
25
+ try:
26
+ # Create environment from Docker image
27
+ rl_code_fix_envenv = RlCodeFixEnv.from_docker_image("rl_code_fix_env-env:latest")
28
 
29
+ # Reset
30
+ result = rl_code_fix_envenv.reset()
31
+ print(f"Reset: {result.observation.echoed_message}")
 
 
 
 
 
32
 
33
+ # Send multiple messages
34
+ messages = ["Hello, World!", "Testing echo", "Final message"]
35
 
36
+ for msg in messages:
37
+ result = rl_code_fix_envenv.step(RlCodeFixAction(message=msg))
38
+ print(f"Sent: '{msg}'")
39
+ print(f" Echoed: '{result.observation.echoed_message}'")
40
+ print(f" Length: {result.observation.message_length}")
41
+ print(f" Reward: {result.reward}")
42
 
43
+ finally:
44
+ # Always clean up
45
+ rl_code_fix_envenv.close()
 
 
46
  ```
47
 
48
+ That's it! The `RlCodeFixEnv.from_docker_image()` method handles:
49
+ - Starting the Docker container
50
+ - Waiting for the server to be ready
51
+ - Connecting to the environment
52
+ - Container cleanup when you call `close()`
 
 
53
 
54
+ ## Building the Docker Image
 
 
55
 
56
+ Before using the environment, you need to build the Docker image:
57
 
58
+ ```bash
59
+ # From project root
60
+ docker build -t rl_code_fix_env-env:latest -f server/Dockerfile .
 
 
 
 
 
 
 
 
 
61
  ```
62
 
63
+ ## Deploying to Hugging Face Spaces
 
 
 
 
 
 
 
 
64
 
65
+ You can easily deploy your OpenEnv environment to Hugging Face Spaces using the `openenv push` command:
66
 
67
+ ```bash
68
+ # From the environment directory (where openenv.yaml is located)
69
+ openenv push
70
 
71
+ # Or specify options
72
+ openenv push --namespace my-org --private
 
 
 
 
73
  ```
74
 
75
+ The `openenv push` command will:
76
+ 1. Validate that the directory is an OpenEnv environment (checks for `openenv.yaml`)
77
+ 2. Prepare a custom build for Hugging Face Docker space (enables web interface)
78
+ 3. Upload to Hugging Face (ensuring you're logged in)
79
 
80
+ ### Prerequisites
 
 
 
81
 
82
+ - Authenticate with Hugging Face: The command will prompt for login if not already authenticated
83
 
84
+ ### Options
85
 
86
+ - `--directory`, `-d`: Directory containing the OpenEnv environment (defaults to current directory)
87
+ - `--repo-id`, `-r`: Repository ID in format 'username/repo-name' (defaults to 'username/env-name' from openenv.yaml)
88
+ - `--base-image`, `-b`: Base Docker image to use (overrides Dockerfile FROM)
89
+ - `--private`: Deploy the space as private (default: public)
90
 
91
+ ### Examples
92
 
93
+ ```bash
94
+ # Push to your personal namespace (defaults to username/env-name from openenv.yaml)
95
+ openenv push
 
 
96
 
97
+ # Push to a specific repository
98
+ openenv push --repo-id my-org/my-env
99
 
100
+ # Push with a custom base image
101
+ openenv push --base-image ghcr.io/meta-pytorch/openenv-base:latest
 
 
 
 
102
 
103
+ # Push as a private space
104
+ openenv push --private
105
 
106
+ # Combine options
107
+ openenv push --repo-id my-org/my-env --base-image custom-base:latest --private
108
+ ```
109
 
110
+ After deployment, your space will be available at:
111
+ `https://huggingface.co/spaces/<repo-id>`
 
112
 
113
+ The deployed space includes:
114
+ - **Web Interface** at `/web` - Interactive UI for exploring the environment
115
+ - **API Documentation** at `/docs` - Full OpenAPI/Swagger interface
116
+ - **Health Check** at `/health` - Container health monitoring
117
+ - **WebSocket** at `/ws` - Persistent session endpoint for low-latency interactions
118
 
119
+ ## Environment Details
 
 
 
 
120
 
121
+ ### Action
122
+ **RlCodeFixAction**: Contains a single field
123
+ - `message` (str) - The message to echo back
124
 
125
+ ### Observation
126
+ **RlCodeFixObservation**: Contains the echo response and metadata
127
+ - `echoed_message` (str) - The message echoed back
128
+ - `message_length` (int) - Length of the message
129
+ - `reward` (float) - Reward based on message length (message_length × 0.1)
130
+ - `done` (bool) - Always False for echo environment
131
+ - `metadata` (dict) - Additional info like step count
132
 
133
+ ### Reward
134
+ The reward is calculated as: `message_length × 0.1`
135
+ - "Hi" reward: 0.2
136
+ - "Hello, World!" reward: 1.3
137
+ - Empty message reward: 0.0
138
 
139
+ ## Advanced Usage
 
 
 
140
 
141
+ ### Connecting to an Existing Server
142
 
143
+ If you already have a Rl Code Fix Env environment server running, you can connect directly:
144
 
145
+ ```python
146
+ from rl_code_fix_env import RlCodeFixEnv
147
 
148
+ # Connect to existing server
149
+ rl_code_fix_envenv = RlCodeFixEnv(base_url="<ENV_HTTP_URL_HERE>")
150
 
151
+ # Use as normal
152
+ result = rl_code_fix_envenv.reset()
153
+ result = rl_code_fix_envenv.step(RlCodeFixAction(message="Hello!"))
 
154
  ```
155
 
156
+ Note: When connecting to an existing server, `rl_code_fix_envenv.close()` will NOT stop the server.
157
 
158
+ ### Using the Context Manager
159
 
160
+ The client supports context manager usage for automatic connection management:
161
 
162
+ ```python
163
+ from rl_code_fix_env import RlCodeFixAction, RlCodeFixEnv
164
+
165
+ # Connect with context manager (auto-connects and closes)
166
+ with RlCodeFixEnv(base_url="http://localhost:8000") as env:
167
+ result = env.reset()
168
+ print(f"Reset: {result.observation.echoed_message}")
169
+ # Multiple steps with low latency
170
+ for msg in ["Hello", "World", "!"]:
171
+ result = env.step(RlCodeFixAction(message=msg))
172
+ print(f"Echoed: {result.observation.echoed_message}")
173
  ```
174
 
175
+ The client uses WebSocket connections for:
176
+ - **Lower latency**: No HTTP connection overhead per request
177
+ - **Persistent session**: Server maintains your environment state
178
+ - **Efficient for episodes**: Better for many sequential steps
179
 
180
+ ### Concurrent WebSocket Sessions
181
 
182
+ The server supports multiple concurrent WebSocket connections. To enable this,
183
+ modify `server/app.py` to use factory mode:
 
 
 
184
 
185
+ ```python
186
+ # In server/app.py - use factory mode for concurrent sessions
187
+ app = create_app(
188
+ RlCodeFixEnvironment, # Pass class, not instance
189
+ RlCodeFixAction,
190
+ RlCodeFixObservation,
191
+ max_concurrent_envs=4, # Allow 4 concurrent sessions
192
+ )
 
 
 
 
193
  ```
194
 
195
+ Then multiple clients can connect simultaneously:
196
 
197
+ ```python
198
+ from rl_code_fix_env import RlCodeFixAction, RlCodeFixEnv
199
+ from concurrent.futures import ThreadPoolExecutor
200
+
201
+ def run_episode(client_id: int):
202
+ with RlCodeFixEnv(base_url="http://localhost:8000") as env:
203
+ result = env.reset()
204
+ for i in range(10):
205
+ result = env.step(RlCodeFixAction(message=f"Client {client_id}, step {i}"))
206
+ return client_id, result.observation.message_length
207
+
208
+ # Run 4 episodes concurrently
209
+ with ThreadPoolExecutor(max_workers=4) as executor:
210
+ results = list(executor.map(run_episode, range(4)))
211
  ```
212
 
213
+ ## Development & Testing
214
 
215
+ ### Direct Environment Testing
 
 
 
 
 
216
 
217
+ Test the environment logic directly without starting the HTTP server:
218
 
219
  ```bash
220
+ # From the server directory
221
+ python3 server/rl_code_fix_env_environment.py
222
  ```
223
 
224
+ This verifies that:
225
+ - Environment resets correctly
226
+ - Step executes actions properly
227
+ - State tracking works
228
+ - Rewards are calculated correctly
229
 
230
+ ### Running Locally
 
 
 
 
 
 
231
 
232
+ Run the server locally for development:
233
 
234
  ```bash
235
+ uvicorn server.app:app --reload
 
236
  ```
237
 
238
+ ## Project Structure
239
 
 
 
 
 
240
  ```
241
+ rl_code_fix_env/
242
+ .dockerignore # Docker build exclusions
243
+ __init__.py # Module exports
244
+ README.md # This file
245
+ openenv.yaml # OpenEnv manifest
246
+ pyproject.toml # Project metadata and dependencies
247
+ uv.lock # Locked dependencies (generated)
248
+ client.py # RlCodeFixEnv client
249
+ models.py # Action and Observation models
250
+ server/
251
+ __init__.py # Server module exports
252
+ rl_code_fix_env_environment.py # Core environment logic
253
+ app.py # FastAPI application (HTTP + WebSocket endpoints)
254
+ Dockerfile # Container image definition
255
+ ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Rl Code Fix Env Environment.

Package entry point: re-exports the action/observation models so callers
can write ``from rl_code_fix_env import CodeFixerAction, CodeFixerObservation``.
"""

from .models import CodeFixerAction, CodeFixerObservation

# Explicit public API for ``from rl_code_fix_env import *``.
__all__ = ["CodeFixerAction", "CodeFixerObservation"]
_aliases.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Alias the local ``dataset`` package under the legacy ``src.dataset`` name.

Importing this module registers ``dataset`` (and each of its immediate
submodules) in ``sys.modules`` under ``src.dataset[.<name>]`` as well, so
code written against the old ``src.dataset`` layout keeps importing.
"""

import importlib
import pkgutil
import sys
from pathlib import Path

# Make the repository root importable so that ``import dataset`` resolves
# to the local package regardless of the current working directory.
_REPO_ROOT = str(Path(__file__).parent)
if _REPO_ROOT not in sys.path:
    sys.path.insert(0, _REPO_ROOT)

import dataset as _real_dataset  # noqa: E402  (requires the sys.path tweak above)

# Alias the package itself; setdefault preserves any alias already registered.
sys.modules.setdefault("src.dataset", _real_dataset)

# Alias each immediate submodule/subpackage of ``dataset`` too.
for _pkg in pkgutil.iter_modules(_real_dataset.__path__):
    _full = f"dataset.{_pkg.name}"
    _alias = f"src.dataset.{_pkg.name}"
    try:
        _mod = importlib.import_module(_full)
        sys.modules.setdefault(_alias, _mod)
    except Exception:
        # Best-effort: a submodule that fails to import is simply not
        # aliased; importing it directly will surface the real error.
        pass
client.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Code Fixer Environment Client."""
8
+
9
+ import asyncio
10
+ import inspect
11
+ import logging
12
+ from typing import Dict
13
+
14
+ from openenv.core import EnvClient
15
+ from openenv.core.client_types import StepResult
16
+ from openenv.core.env_server.types import State
17
+
18
+ from rl_code_fix_env.models import CodeFixerAction, CodeFixerObservation
19
+
20
+ log = logging.getLogger(__name__)
21
+
22
class CodeFixerEnv(
    EnvClient[CodeFixerAction, CodeFixerObservation, State]
):
    """
    Client for the Code Fixer Environment.

    This client maintains a persistent WebSocket connection to the environment
    server, enabling efficient multi-step interactions with lower latency.
    Each client instance has its own dedicated environment session on the
    server, driven by a private asyncio event loop so synchronous callers can
    use the async ``EnvClient`` API.

    Example:
        >>> # Connect to a running server
        >>> with CodeFixerEnv(base_url="http://localhost:8000") as client:
        ...     result = client.reset()
        ...     print(result.observation.code)
        ...
        ...     result = client.step(CodeFixerAction(type="run_tests"))
        ...     print(result.observation.test_score)

    Example with Docker:
        >>> # Automatically start container and connect
        >>> client = CodeFixerEnv.from_docker_image("code_fixer-env:latest")
        >>> try:
        ...     result = client.reset()
        ...     result = client.step(CodeFixerAction(type="run_tests"))
        ... finally:
        ...     client.close()
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Dedicated event loop: lets synchronous callers drive the async
        # EnvClient API without depending on any ambient running loop.
        self._loop = asyncio.new_event_loop()
        # Store init args so _reconnect() can rebuild the client from scratch.
        self._init_args = args
        self._init_kwargs = kwargs

    def _run_sync(self, result):
        """Run coroutine results on this client's dedicated event loop.

        Non-coroutine values are returned unchanged, so this works whether
        the base-class method is sync or async.
        """
        if inspect.iscoroutine(result):
            if self._loop.is_closed():
                # Close the coroutine to avoid a "never awaited" warning,
                # then fail with a clearer message than asyncio's default
                # RuntimeError("Event loop is closed").
                result.close()
                raise RuntimeError("CodeFixerEnv event loop is closed")
            return self._loop.run_until_complete(result)
        return result

    def _reconnect(self) -> None:
        """
        Tear down the dead event loop and WebSocket connection, then
        re-initialise so the next call works cleanly.

        Called automatically by reset() and step() when a 1011 / timeout
        error is detected after an idle period.
        """
        log.warning("[CodeFixerEnv] WebSocket timed out; reconnecting...")
        # Best-effort close of the old connection; it may already be dead.
        try:
            self._run_sync(super().close())
        except Exception:
            pass
        if not self._loop.is_closed():
            self._loop.close()

        # Re-initialise: fresh loop + fresh base-class state.
        self._loop = asyncio.new_event_loop()
        super().__init__(*self._init_args, **self._init_kwargs)
        log.warning("[CodeFixerEnv] Reconnected successfully.")

    @staticmethod
    def _is_reconnectable_ws_error(exc: Exception) -> bool:
        """Heuristically decide whether *exc* looks like a dropped WebSocket.

        Matches on substrings of the error message: close codes 1011/1006,
        keepalive timeouts, and generic "connection closed" wording.
        NOTE(review): the bare "closed" marker is very broad and may match
        unrelated errors; a false positive only costs one reconnect attempt.
        """
        err = str(exc).lower()
        reconnect_markers = (
            "1011",
            "1006",
            "keepalive",
            "timed out",
            "closed",
            "close frame",
            "connection closed",
            "connectionclosed",
            "websocket",
        )
        return any(marker in err for marker in reconnect_markers)

    def reset(self):
        """Reset the environment; auto-reconnects once if the WebSocket died."""
        try:
            return self._run_sync(super().reset())
        except Exception as exc:
            if self._is_reconnectable_ws_error(exc):
                self._reconnect()
                return self._run_sync(super().reset())  # one retry
            raise

    def step(self, action: CodeFixerAction):
        """Execute a step; auto-reconnects once if the WebSocket died."""
        try:
            return self._run_sync(super().step(action))
        except Exception as exc:
            if self._is_reconnectable_ws_error(exc):
                self._reconnect()
                return self._run_sync(super().step(action))  # one retry
            raise

    def close(self):
        """Close client resources and the dedicated event loop safely.

        Idempotent: a second close() is a no-op instead of driving a
        coroutine onto the already-closed loop (which previously could
        raise RuntimeError).
        """
        if self._loop.is_closed():
            return
        try:
            self._run_sync(super().close())
        finally:
            if not self._loop.is_closed():
                self._loop.close()

    def _step_payload(self, action: CodeFixerAction) -> Dict:
        """
        Convert CodeFixerAction to JSON payload for step message.

        Args:
            action: CodeFixerAction instance

        Returns:
            Dictionary representation suitable for JSON encoding
        """
        return {
            "type": action.type,
            "payload": action.payload,
        }

    def _parse_result(self, payload: Dict) -> StepResult[CodeFixerObservation]:
        """
        Parse server response into StepResult[CodeFixerObservation].

        Falls back to the top-level "done"/"reward" fields when the nested
        observation dict omits them, so older payload shapes keep working.

        Args:
            payload: JSON response data from server

        Returns:
            StepResult with CodeFixerObservation
        """
        obs_data = payload.get("observation", {})
        observation = CodeFixerObservation(
            code=obs_data.get("code", ""),
            logs=obs_data.get("logs"),
            test_score=float(obs_data.get("test_score", 0.0)),
            total_tests=obs_data.get("total_tests", 1),
            steps=obs_data.get("steps", 0),
            done=obs_data.get("done", payload.get("done", False)),
            reward=obs_data.get("reward", payload.get("reward")),
        )

        return StepResult(
            observation=observation,
            reward=payload.get("reward"),
            done=payload.get("done", False),
        )

    def _parse_state(self, payload: Dict) -> State:
        """
        Parse server response into State object.

        Args:
            payload: JSON response from state request

        Returns:
            State object with episode_id and step_count
        """
        return State(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
        )
+ )
conftest.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Pytest bootstrap: make the repo root importable and alias ``dataset`` as ``src.dataset``."""
import sys
import importlib
from pathlib import Path

# Ensure the repository root (the directory containing this conftest) is
# importable so `import dataset` resolves regardless of the invocation cwd.
_REPO_ROOT = str(Path(__file__).parent)
if _REPO_ROOT not in sys.path:
    sys.path.insert(0, _REPO_ROOT)

import dataset as _real_dataset

# Some callers import the package under the ``src.dataset`` name; point that
# alias at the real package (setdefault: do not clobber an existing entry).
sys.modules.setdefault("src.dataset", _real_dataset)

import pkgutil
# Mirror every top-level submodule of ``dataset`` under the ``src.dataset``
# prefix as well, so ``import src.dataset.<mod>`` finds the same module object.
for _pkg in pkgutil.iter_modules(_real_dataset.__path__):
    _full = f"dataset.{_pkg.name}"
    _alias = f"src.dataset.{_pkg.name}"
    try:
        _mod = importlib.import_module(_full)
        sys.modules.setdefault(_alias, _mod)
    except Exception:
        # Best-effort aliasing: submodules that fail to import (e.g. missing
        # optional dependencies) are simply skipped.
        pass
dataset/README.md ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Buggy Training Dataset
2
+
3
+ This dataset is organized as:
4
+
5
+ - `problem_x/buggy.py`: intentionally buggy implementation
6
+ - `problem_x/test.py`: correctness tests that should fail before fixes
7
+ - optional extra modules (`helpers.py`, `cache.py`, etc.) to support multi-file bug fixing
8
+
9
+ Current problems: `problem_1` to `problem_18`.
10
+
11
+ Bug patterns included:
12
+ - off-by-one errors
13
+ - boundary condition mistakes
14
+ - incorrect sorting direction
15
+ - exception handling mistakes
16
+ - state/recency bugs in cache logic
17
+ - recursive base-case bugs
18
+ - parsing and whitespace normalization issues
19
+ - order-preservation regressions
20
+ - matrix transformation direction errors
dataset/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Dataset loading modules."""
dataset/generate_swebench_tasks.py ADDED
@@ -0,0 +1,498 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Generate synthetic SWE-bench style tasks for testing.
3
+
4
+ This creates tasks that mimic the SWE-bench format:
5
+ - instance_id/buggy.py - the buggy code
6
+ - instance_id/test.py - test file
7
+ - instance_id/metadata.json - metadata
8
+
9
+ Usage:
10
+ python -m dataset.generate_swebench_tasks [--count N]
11
+ """
12
+
13
+ import argparse
14
+ import json
15
+ import random
16
+ from pathlib import Path
17
+
18
+
19
+ # Sample SWE-bench style problems
20
+ SWE_BENCH_PROBLEMS = [
21
+ {
22
+ "instance_id": "django__django-11098",
23
+ "repo": "django/django",
24
+ "problem": "Fix the user creation form validation error",
25
+ "buggy_code": '''from django import forms
26
+ from django.contrib.auth.models import User
27
+
28
+ class UserCreationForm(forms.ModelForm):
29
+ """Form for creating new users."""
30
+ password1 = forms.CharField(widget=forms.PasswordInput)
31
+ password2 = forms.CharField(widget=forms.PasswordInput)
32
+
33
+ class Meta:
34
+ model = User
35
+ fields = ('username', 'email')
36
+
37
+ def clean(self):
38
+ cleaned_data = super().clean()
39
+ password1 = cleaned_data.get('password1')
40
+ password2 = cleaned_data.get('password2')
41
+
42
+ # BUG: This comparison is case-sensitive but should be case-insensitive
43
+ if password1 != password2:
44
+ raise forms.ValidationError("Passwords don't match")
45
+
46
+ return cleaned_data
47
+
48
+ def save(self, commit=True):
49
+ user = super().save(commit=False)
50
+ user.set_password(self.cleaned_data['password1'])
51
+ if commit:
52
+ user.save()
53
+ return user
54
+ ''',
55
+ "test_code": '''import unittest
56
+ from buggy import UserCreationForm
57
+
58
+ class TestUserCreationForm(unittest.TestCase):
59
+ def test_password_matching(self):
60
+ """Test that matching passwords pass validation."""
61
+ form = UserCreationForm(data={
62
+ 'username': 'testuser',
63
+ 'email': 'test@example.com',
64
+ 'password1': 'TestPass123',
65
+ 'password2': 'TestPass123',
66
+ })
67
+ self.assertTrue(form.is_valid())
68
+
69
+ def test_password_mismatch(self):
70
+ """Test that mismatched passwords fail validation."""
71
+ form = UserCreationForm(data={
72
+ 'username': 'testuser',
73
+ 'email': 'test@example.com',
74
+ 'password1': 'TestPass123',
75
+ 'password2': 'testpass123', # Different case
76
+ })
77
+ self.assertFalse(form.is_valid())
78
+ self.assertIn('passwords', str(form.errors).lower())
79
+ ''',
80
+ },
81
+ {
82
+ "instance_id": "flask__flask-1048",
83
+ "repo": "pallets/flask",
84
+ "problem": "Fix JSON encoding for datetime objects",
85
+ "buggy_code": '''import json
86
+ from datetime import datetime, date
87
+
88
+ class JSONEncoder(json.JSONEncoder):
89
+ """Custom JSON encoder for Flask."""
90
+
91
+ def default(self, obj):
92
+ # BUG: Missing handling for datetime objects
93
+ if isinstance(obj, date):
94
+ return obj.isoformat()
95
+ return super().default(obj)
96
+
97
+ def to_json(obj):
98
+ """Convert object to JSON string."""
99
+ return json.dumps(obj, cls=JSONEncoder)
100
+ ''',
101
+ "test_code": '''import unittest
102
+ from datetime import datetime
103
+ from buggy import to_json
104
+
105
+ class TestJSONEncoding(unittest.TestCase):
106
+ def test_encode_datetime(self):
107
+ """Test that datetime objects are properly encoded."""
108
+ dt = datetime(2024, 1, 15, 10, 30, 0)
109
+ result = to_json({'timestamp': dt})
110
+ self.assertIn('2024-01-15', result)
111
+ self.assertIn('10:30:00', result)
112
+
113
+ def test_encode_date(self):
114
+ """Test that date objects are properly encoded."""
115
+ d = date(2024, 1, 15)
116
+ result = to_json({'date': d})
117
+ self.assertIn('2024-01-15', result)
118
+ ''',
119
+ },
120
+ {
121
+ "instance_id": "requests__requests-2875",
122
+ "repo": "psf/requests",
123
+ "problem": "Fix cookie domain matching",
124
+ "buggy_code": '''import re
125
+ from urllib.parse import urlparse
126
+
127
+ def match_cookie_domain(cookie_domain, request_domain):
128
+ """Check if cookie domain matches request domain."""
129
+ # BUG: Should handle leading dots differently
130
+ # .example.com should match sub.example.com but not example.com
131
+ cookie_domain = cookie_domain.lower()
132
+ request_domain = request_domain.lower()
133
+
134
+ if cookie_domain.startswith('.'):
135
+ return request_domain.endswith(cookie_domain)
136
+
137
+ return cookie_domain == request_domain
138
+ ''',
139
+ "test_code": '''import unittest
140
+ from buggy import match_cookie_domain
141
+
142
+ class TestCookieDomain(unittest.TestCase):
143
+ def test_exact_match(self):
144
+ """Test exact domain matching."""
145
+ self.assertTrue(match_cookie_domain('example.com', 'example.com'))
146
+
147
+ def test_subdomain_with_dot(self):
148
+ """Test subdomain matching with leading dot."""
149
+ # .example.com should match sub.example.com
150
+ self.assertTrue(match_cookie_domain('.example.com', 'sub.example.com'))
151
+ self.assertFalse(match_cookie_domain('.example.com', 'example.com'))
152
+
153
+ def test_different_domains(self):
154
+ """Test different domains don't match."""
155
+ self.assertFalse(match_cookie_domain('example.com', 'other.com'))
156
+ ''',
157
+ },
158
+ {
159
+ "instance_id": "numpy__numpy-10825",
160
+ "repo": "numpy/numpy",
161
+ "problem": "Fix array concatenation edge case",
162
+ "buggy_code": '''import numpy as np
163
+
164
+ def concatenate_arrays(*arrays):
165
+ """Concatenate multiple arrays along axis 0."""
166
+ if not arrays:
167
+ return np.array([])
168
+
169
+ # BUG: Should handle None arrays gracefully
170
+ result = arrays[0]
171
+ for arr in arrays[1:]:
172
+ result = np.concatenate([result, arr])
173
+
174
+ return result
175
+ ''',
176
+ "test_code": '''import unittest
177
+ import numpy as np
178
+ from buggy import concatenate_arrays
179
+
180
+ class TestArrayConcatenation(unittest.TestCase):
181
+ def test_basic_concatenation(self):
182
+ """Test basic array concatenation."""
183
+ a = np.array([1, 2, 3])
184
+ b = np.array([4, 5, 6])
185
+ result = concatenate_arrays(a, b)
186
+ np.testing.assert_array_equal(result, np.array([1, 2, 3, 4, 5, 6]))
187
+
188
+ def test_empty_input(self):
189
+ """Test empty input returns empty array."""
190
+ result = concatenate_arrays()
191
+ self.assertEqual(len(result), 0)
192
+
193
+ def test_single_array(self):
194
+ """Test single array passes through."""
195
+ a = np.array([1, 2, 3])
196
+ result = concatenate_arrays(a)
197
+ np.testing.assert_array_equal(result, a)
198
+ ''',
199
+ },
200
+ {
201
+ "instance_id": "pandas__pandas-15230",
202
+ "repo": "pandas-dev/pandas",
203
+ "problem": "Fix DataFrame groupby aggregation",
204
+ "buggy_code": '''import pandas as pd
205
+
206
+ def group_and_aggregate(df, group_col, agg_col, agg_func='mean'):
207
+ """Group DataFrame and aggregate."""
208
+ # BUG: Should handle non-numeric columns gracefully
209
+ if agg_func == 'mean':
210
+ return df.groupby(group_col)[agg_col].mean()
211
+ elif agg_func == 'sum':
212
+ return df.groupby(group_col)[agg_col].sum()
213
+ elif agg_func == 'count':
214
+ return df.groupby(group_col)[agg_col].count()
215
+ else:
216
+ raise ValueError(f"Unknown aggregation function: {agg_func}")
217
+ ''',
218
+ "test_code": '''import unittest
219
+ import pandas as pd
220
+ from buggy import group_and_aggregate
221
+
222
+ class TestGroupBy(unittest.TestCase):
223
+ def test_mean_aggregation(self):
224
+ """Test mean aggregation."""
225
+ df = pd.DataFrame({
226
+ 'category': ['A', 'A', 'B', 'B'],
227
+ 'value': [1, 2, 3, 4]
228
+ })
229
+ result = group_and_aggregate(df, 'category', 'value', 'mean')
230
+ self.assertEqual(result['A'], 1.5)
231
+ self.assertEqual(result['B'], 3.5)
232
+
233
+ def test_sum_aggregation(self):
234
+ """Test sum aggregation."""
235
+ df = pd.DataFrame({
236
+ 'category': ['A', 'A', 'B'],
237
+ 'value': [1, 2, 3]
238
+ })
239
+ result = group_and_aggregate(df, 'category', 'value', 'sum')
240
+ self.assertEqual(result['A'], 3)
241
+ self.assertEqual(result['B'], 3)
242
+ ''',
243
+ },
244
+ {
245
+ "instance_id": "scipy__scipy-1925",
246
+ "repo": "scipy/scipy",
247
+ "problem": "Fix signal filtering edge case",
248
+ "buggy_code": '''import numpy as np
249
+ from scipy import signal
250
+
251
+ def apply_lowpass_filter(data, cutoff, fs, order=5):
252
+ """Apply lowpass filter to data."""
253
+ # BUG: Should validate cutoff frequency
254
+ nyquist = fs / 2
255
+ normalized_cutoff = cutoff / nyquist
256
+
257
+ # BUG: Using invalid cutoff can cause filter design failure
258
+ b, a = signal.butter(order, normalized_cutoff, btype='low')
259
+ filtered = signal.filtfilt(b, a, data)
260
+
261
+ return filtered
262
+ ''',
263
+ "test_code": '''import unittest
264
+ import numpy as np
265
+ from buggy import apply_lowpass_filter
266
+
267
+ class TestSignalFiltering(unittest.TestCase):
268
+ def test_valid_filter(self):
269
+ """Test filtering with valid parameters."""
270
+ fs = 1000 # Sampling frequency
271
+ cutoff = 100 # Cutoff frequency
272
+ t = np.linspace(0, 1, fs)
273
+ data = np.sin(2 * np.pi * 50 * t) + 0.5 * np.sin(2 * np.pi * 200 * t)
274
+
275
+ result = apply_lowpass_filter(data, cutoff, fs)
276
+ self.assertEqual(len(result), len(data))
277
+ # Low frequency component should be preserved
278
+ self.assertTrue(np.abs(result[100]) > 0.5)
279
+
280
+ def test_invalid_cutoff(self):
281
+ """Test that invalid cutoff raises error."""
282
+ fs = 1000
283
+ cutoff = 2000 # Above Nyquist frequency - should fail
284
+ data = np.array([1, 2, 3, 4, 5])
285
+
286
+ with self.assertRaises(ValueError):
287
+ apply_lowpass_filter(data, cutoff, fs)
288
+ ''',
289
+ },
290
+ {
291
+ "instance_id": "sklearn__sklearn-12345",
292
+ "repo": "scikit-learn/scikit-learn",
293
+ "problem": "Fix cross-validation split",
294
+ "buggy_code": '''import numpy as np
295
+ from sklearn.model_selection import KFold
296
+
297
+ def get_cv_splits(X, n_splits=5, shuffle=True, random_state=42):
298
+ """Get cross-validation splits."""
299
+ # BUG: random_state should be used for reproducibility
300
+ kf = KFold(n_splits=n_splits, shuffle=shuffle)
301
+
302
+ splits = []
303
+ for train_idx, test_idx in kf.split(X):
304
+ splits.append((train_idx, test_idx))
305
+
306
+ return splits
307
+ ''',
308
+ "test_code": '''import unittest
309
+ import numpy as np
310
+ from buggy import get_cv_splits
311
+
312
+ class TestCVSplits(unittest.TestCase):
313
+ def test_split_count(self):
314
+ """Test that correct number of splits is generated."""
315
+ X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]])
316
+ splits = get_cv_splits(X, n_splits=3)
317
+ self.assertEqual(len(splits), 3)
318
+
319
+ def test_reproducibility(self):
320
+ """Test that splits are reproducible with same random_state."""
321
+ X = np.random.rand(100, 5)
322
+ splits1 = get_cv_splits(X, n_splits=5, random_state=42)
323
+ splits2 = get_cv_splits(X, n_splits=5, random_state=42)
324
+
325
+ for (train1, test1), (train2, test2) in zip(splits1, splits2):
326
+ np.testing.assert_array_equal(train1, train2)
327
+ np.testing.assert_array_equal(test1, test2)
328
+ ''',
329
+ },
330
+ {
331
+ "instance_id": "pytest__pytest-7426",
332
+ "repo": "pytest-dev/pytest",
333
+ "problem": "Fix test collection order",
334
+ "buggy_code": '''import os
335
+ import re
336
+
337
+ def collect_tests(directory, pattern='test_*.py'):
338
+ """Collect test files from directory."""
339
+ # BUG: Should sort files for consistent ordering
340
+ test_files = []
341
+
342
+ for root, dirs, files in os.walk(directory):
343
+ for file in files:
344
+ if re.match(pattern, file):
345
+ test_files.append(os.path.join(root, file))
346
+
347
+ return test_files
348
+ ''',
349
+ "test_code": '''import unittest
350
+ import os
351
+ import tempfile
352
+ from buggy import collect_tests
353
+
354
+ class TestCollection(unittest.TestCase):
355
+ def test_collect_pattern(self):
356
+ """Test that correct pattern is matched."""
357
+ with tempfile.TemporaryDirectory() as tmpdir:
358
+ # Create test files
359
+ open(os.path.join(tmpdir, 'test_a.py'), 'w').close()
360
+ open(os.path.join(tmpdir, 'test_b.py'), 'w').close()
361
+ open(os.path.join(tmpdir, 'not_a_test.py'), 'w').close()
362
+
363
+ tests = collect_tests(tmpdir, 'test_*.py')
364
+ self.assertEqual(len(tests), 2)
365
+
366
+ def test_consistent_order(self):
367
+ """Test that file order is consistent."""
368
+ with tempfile.TemporaryDirectory() as tmpdir:
369
+ for name in ['test_c.py', 'test_a.py', 'test_b.py']:
370
+ open(os.path.join(tmpdir, name), 'w').close()
371
+
372
+ tests1 = collect_tests(tmpdir)
373
+ tests2 = collect_tests(tmpdir)
374
+
375
+ self.assertEqual(tests1, tests2)
376
+ ''',
377
+ },
378
+ {
379
+ "instance_id": "transformers__transformers-12345",
380
+ "repo": "huggingface/transformers",
381
+ "problem": "Fix tokenization padding",
382
+ "buggy_code": '''from typing import List
383
+
384
+ def tokenize_and_pad(tokenizer, texts: List[str], max_length: int = 512):
385
+ """Tokenize texts and pad to max length."""
386
+ # BUG: Should handle padding correctly
387
+ encoded = tokenizer(
388
+ texts,
389
+ padding=True, # This pads to longest in batch, not max_length
390
+ truncation=True,
391
+ max_length=max_length,
392
+ return_tensors='pt'
393
+ )
394
+
395
+ return encoded
396
+ ''',
397
+ "test_code": '''import unittest
398
+ from buggy import tokenize_and_pad
399
+
400
+ class MockTokenizer:
401
+ def __call__(self, texts, padding=True, truncation=True, max_length=512, return_tensors=None):
402
+ # Simplified mock
403
+ return {
404
+ 'input_ids': [[1, 2, 3]] if isinstance(texts, list) else [1, 2, 3],
405
+ 'attention_mask': [[1, 1, 1]] if isinstance(texts, list) else [1, 1, 1]
406
+ }
407
+
408
+ class TestTokenization(unittest.TestCase):
409
+ def test_single_text(self):
410
+ """Test tokenizing single text."""
411
+ tokenizer = MockTokenizer()
412
+ result = tokenize_and_pad(tokenizer, ["hello world"])
413
+ self.assertIn('input_ids', result)
414
+
415
+ def test_max_length_respected(self):
416
+ """Test that max_length is respected."""
417
+ tokenizer = MockTokenizer()
418
+ # Should not raise even with long text
419
+ result = tokenize_and_pad(tokenizer, ["short"], max_length=10)
420
+ self.assertIn('input_ids', result)
421
+ ''',
422
+ },
423
+ ]
424
+
425
# Easy, Medium, Hard difficulty assignments
# NOTE(review): difficulty is assigned purely by position in
# SWE_BENCH_PROBLEMS (first three = easy, next three = medium, rest = hard);
# confirm the list ordering actually reflects difficulty.
DIFFICULTY_TASKS = {
    "easy": SWE_BENCH_PROBLEMS[:3],
    "medium": SWE_BENCH_PROBLEMS[3:6],
    "hard": SWE_BENCH_PROBLEMS[6:],
}
431
+
432
+
433
def generate_tasks(output_dir: Path, count_per_difficulty: int = 3):
    """Materialize SWE-bench style task folders under *output_dir*.

    For each difficulty bucket, up to *count_per_difficulty* problems are
    written as ``<instance_id>_<difficulty>_<i>/`` containing ``buggy.py``,
    ``test.py`` and ``metadata.json``.

    Args:
        output_dir: Destination directory (created if missing).
        count_per_difficulty: Cap on problems taken from each bucket.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    total_created = 0
    for difficulty, problems in DIFFICULTY_TASKS.items():
        for i, problem in enumerate(problems[:count_per_difficulty]):
            # Suffix keeps instance ids unique across difficulty buckets.
            instance_id = f"{problem['instance_id']}_{difficulty}_{i}"
            instance_dir = output_dir / instance_id
            instance_dir.mkdir(parents=True, exist_ok=True)

            # One file per artifact: buggy source, tests, and metadata.
            (instance_dir / "buggy.py").write_text(
                problem["buggy_code"], encoding="utf-8"
            )
            (instance_dir / "test.py").write_text(
                problem["test_code"], encoding="utf-8"
            )
            metadata = {
                "instance_id": instance_id,
                "repo": problem["repo"],
                "problem_statement": problem["problem"],
                "difficulty": difficulty,
            }
            (instance_dir / "metadata.json").write_text(
                json.dumps(metadata, indent=2), encoding="utf-8"
            )

            total_created += 1

    print(f"Created {total_created} tasks in {output_dir}")
    print(f"Set environment variable: SWEBENCH_TASKS_ROOT={output_dir.absolute()}")
    print(f"Or run with: TASK_SOURCE=swebench python inference.py")
469
+
470
+
471
def main():
    """CLI entry point: parse arguments and generate the task folders."""
    parser = argparse.ArgumentParser(description="Generate SWE-bench style tasks")
    parser.add_argument(
        "--count",
        type=int,
        default=3,
        help="Number of tasks per difficulty (default: 3)",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default=None,
        help="Output directory (default: dataset/swebench_lite_tasks)",
    )
    args = parser.parse_args()

    # Default output directory lives next to this script.
    if args.output_dir:
        target = Path(args.output_dir)
    else:
        target = Path(__file__).parent / "swebench_lite_tasks"

    generate_tasks(target, args.count)
495
+
496
+
497
+ if __name__ == "__main__":
498
+ main()
dataset/loader.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Load static, competition-approved tasks."""
2
+
3
+ import os
4
+ import json
5
+ from pathlib import Path
6
+ from typing import Dict, List, Optional
7
+
8
+ # Get the dataset root (same folder as this file)
9
+ DATASET_ROOT = Path(__file__).parent
10
+
11
+ # Hardcoded competition tasks: Easy Medium Hard
12
+ STATIC_TASKS = {
13
+ "easy": {
14
+ "problem_id": "problem_1",
15
+ "difficulty": "easy",
16
+ "description": "String reversal with space normalization",
17
+ },
18
+ "medium": {
19
+ "problem_id": "problem_10",
20
+ "difficulty": "medium",
21
+ "description": "Matrix 90 clockwise rotation",
22
+ },
23
+ "hard": {
24
+ "problem_id": "problem_13",
25
+ "difficulty": "hard",
26
+ "description": "LRU cache with correct eviction policy",
27
+ },
28
+ }
29
+
30
+
31
def load_problem(problem_id: str) -> Dict[str, object]:
    """
    Load a single problem from disk.

    Args:
        problem_id: e.g., "problem_1", "problem_10", "problem_13"

    Returns:
        {
            "code": str,         # buggy.py content
            "tests": str,        # absolute path to test.py
            "metadata": dict,    # parsed metadata.json
            "problem_dir": str,  # absolute path to problem folder
            "problem_id": str,   # the requested id
        }

    Raises:
        FileNotFoundError: if the problem directory does not exist.
    """
    # Fix: the annotation previously used ``Dict[str, any]`` — ``any`` is the
    # builtin function, not a type; ``object`` is the correct "anything" type.
    problem_dir = DATASET_ROOT / problem_id

    if not problem_dir.exists():
        raise FileNotFoundError(f"Problem directory not found: {problem_dir}")

    # Buggy source and metadata are mandatory files in every problem folder.
    code = (problem_dir / "buggy.py").read_text(encoding="utf-8")
    metadata = json.loads((problem_dir / "metadata.json").read_text(encoding="utf-8"))

    return {
        "code": code,
        "tests": str(problem_dir / "test.py"),
        "metadata": metadata,
        "problem_dir": str(problem_dir),
        "problem_id": problem_id,
    }
69
+
70
+
71
def get_hardcoded_task(difficulty: str) -> Dict[str, object]:
    """
    Get one of the three static competition tasks.

    Args:
        difficulty: "easy" | "medium" | "hard"

    Returns:
        Task dict with code, tests, metadata (see ``load_problem``).

    Raises:
        ValueError: if difficulty is not one of the three approved values
    """
    # Fix: the annotation previously used ``Dict[str, any]`` — ``any`` is the
    # builtin function, not a type; ``object`` is the correct "anything" type.
    if difficulty not in STATIC_TASKS:
        raise ValueError(
            f"Invalid difficulty '{difficulty}'. "
            f"Must be one of: {list(STATIC_TASKS.keys())}"
        )

    task_info = STATIC_TASKS[difficulty]
    return load_problem(task_info["problem_id"])
94
+
95
+
96
def get_random_tasks():
    """
    DEPRECATED: Use get_hardcoded_task() instead.
    Kept for backward compatibility; always resolves to the "easy" task.
    """
    import warnings

    warnings.warn(
        "get_random_tasks() is deprecated. Use get_hardcoded_task('easy'|'medium'|'hard')",
        DeprecationWarning,
        stacklevel=2,
    )
    return get_hardcoded_task("easy")
109
+
110
+
111
+
dataset/prepare_swebench.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Script to download and materialize SWE-bench Lite tasks.
3
+
4
+ This script:
5
+ 1. Downloads SWE-bench Lite dataset from HuggingFace
6
+ 2. Extracts the buggy code and creates test files
7
+ 3. Organizes them into the expected directory structure
8
+
9
+ Usage:
10
+ python -m dataset.prepare_swebench [--max-tasks N] [--difficulty easy|medium|hard|all]
11
+ """
12
+
13
+ import argparse
14
+ import os
15
+ import sys
16
+ from pathlib import Path
17
+
18
+ # Add parent to path for imports
19
+ sys.path.insert(0, str(Path(__file__).parent.parent))
20
+
21
+ from datasets import load_dataset
22
+
23
+
24
def get_problem_statement(row):
    """Return the problem statement stored on *row* ("" when the key is absent)."""
    return row.get("problem_statement", "")
27
+
28
+
29
def get_patch(row):
    """Return the gold patch/fix text from *row* ("" when the key is absent)."""
    return row.get("patch", "")
32
+
33
+
34
def get_instance_id(row):
    """Return the SWE-bench instance id from *row* ("" when the key is absent)."""
    return row.get("instance_id", "")
37
+
38
+
39
def create_buggy_file(instance_dir: Path, row):
    """
    Create buggy.py from the base commit and instance.

    NOTE(review): a faithful SWE-bench materialization would check out the
    repo at ``base_commit`` and extract the affected file; this writes a
    placeholder module derived from the problem statement instead — confirm
    that is acceptable for the intended use.

    Args:
        instance_dir: Folder to write into.
        row: Dataset row with instance_id / problem_statement / created_files.

    Returns:
        Path of the written ``buggy.py``.
    """
    instance_id = get_instance_id(row)
    problem_stmt = get_problem_statement(row)
    created_files = row.get("created_files", [])

    if not created_files:
        # No file list available: generic placeholder.
        buggy_code = f'''# Buggy code for {instance_id}
# Problem: {problem_stmt[:200]}...

def solution():
    """Placeholder solution - needs to be fixed."""
    pass
'''
    else:
        # File list available: record the first affected file in the header.
        file_path = created_files[0] if created_files else "solution.py"
        buggy_code = f'''# Buggy code for {instance_id}
# File: {file_path}
# Problem: {problem_stmt[:200]}...

def solution():
    """Placeholder solution - needs to be fixed."""
    pass
'''

    out_path = instance_dir / "buggy.py"
    out_path.write_text(buggy_code, encoding="utf-8")
    return out_path
86
+
87
+
88
def create_test_file(instance_dir: Path, row):
    """
    Create test.py based on the problem statement.

    Fixes over the previous version:
    - emits real newlines (the old code wrote the two-character sequence
      ``\\n`` as literal text into the file, producing an unusable one-line
      test module);
    - generates a single ``TestSolution`` class with one method per test
      case (the old loop re-declared the class for every case, so only the
      last case survived).

    Args:
        instance_dir: Folder to write into.
        row: Dataset row with instance_id / problem_statement / test_cases.

    Returns:
        Path of the written ``test.py``.
    """
    instance_id = get_instance_id(row)
    problem_stmt = get_problem_statement(row)
    test_cases = row.get("test_cases", [])

    if test_cases:
        # Build one class with a method per provided test case.
        lines = ["import unittest", "", "", "class TestSolution(unittest.TestCase):"]
        for i, tc in enumerate(test_cases):
            input_str = tc.get("input", "")
            output_str = tc.get("output", "")
            lines.extend([
                f"    def test_case_{i + 1}(self):",
                f"        # Input: {input_str}",
                f"        # Expected: {output_str}",
                "        pass  # TODO: Add actual test",
                "",
            ])
        test_code = "\n".join(lines) + "\n"
    else:
        # Fallback: a smoke test derived from the problem statement.
        test_code = f'''"""Test file for {instance_id}"""

import unittest
from buggy import solution


class TestSolution(unittest.TestCase):
    def test_basic(self):
        """Test based on problem statement."""
        # Problem: {problem_stmt[:300]}...
        result = solution()
        self.assertIsNotNone(result)


if __name__ == "__main__":
    unittest.main()
'''

    test_file = instance_dir / "test.py"
    test_file.write_text(test_code, encoding="utf-8")
    return test_file
137
+
138
+
139
def create_metadata_file(instance_dir: Path, row):
    """Write ``metadata.json`` describing the instance; return its path.

    Args:
        instance_dir: Folder to write into.
        row: Dataset row providing repo / base_commit / statement / patch.
    """
    import json

    payload = {
        "instance_id": get_instance_id(row),
        "repo": row.get("repo", ""),
        "base_commit": row.get("base_commit", ""),
        "problem_statement": get_problem_statement(row),
        "patch": get_patch(row),
        # NOTE(review): difficulty is a fixed placeholder; the original
        # comment says "will be set based on index" but nothing updates it.
        "difficulty": "medium",
    }

    metadata_file = instance_dir / "metadata.json"
    metadata_file.write_text(json.dumps(payload, indent=2), encoding="utf-8")
    return metadata_file
155
+
156
+
157
def prepare_swebench_tasks(
    output_dir: Path,
    max_tasks: int = 30,
    difficulty: str = "all"
):
    """
    Download and prepare SWE-bench Lite tasks.

    Args:
        output_dir: Directory to save tasks
        max_tasks: Maximum number of tasks to download
        difficulty: "easy", "medium", "hard", or "all"
    """
    print(f"Loading SWE-bench Lite dataset...")

    try:
        ds = load_dataset("princeton-nlp/SWE-bench_Lite", split="test")
    except Exception as e:
        # Fall back to the mirrored dataset name when the primary load fails.
        print(f"Error loading dataset: {e}")
        print("Trying alternative dataset name...")
        ds = load_dataset("swe-bench/swe-bench-lite", split="test")

    print(f"Loaded {len(ds)} tasks")

    # Calculate difficulty bounds
    # NOTE(review): difficulty is a pure index proxy (first third = easy,
    # middle third = medium, last third = hard) — confirm this is acceptable.
    total = len(ds)
    one_third = max(total // 3, 1)
    two_third = max((2 * total) // 3, one_third + 1)

    difficulty_ranges = {
        "easy": (0, one_third),
        "medium": (one_third, two_third),
        "hard": (two_third, total),
    }

    # Determine which tasks to download
    if difficulty == "all":
        # Take an even share (max_tasks // 3) from each difficulty band.
        ranges = list(difficulty_ranges.values())
        indices = []
        for start, end in ranges:
            indices.extend(range(start, min(end, start + max_tasks // 3)))
    else:
        start, end = difficulty_ranges.get(difficulty, (0, total))
        indices = list(range(start, min(end, max_tasks)))

    # Create output directory
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"Preparing {len(indices)} tasks...")

    success_count = 0
    for i, idx in enumerate(indices):
        try:
            row = ds[idx]
            instance_id = get_instance_id(row)

            # Create instance directory
            instance_dir = output_dir / instance_id
            instance_dir.mkdir(parents=True, exist_ok=True)

            # Create files
            create_buggy_file(instance_dir, row)
            create_test_file(instance_dir, row)
            create_metadata_file(instance_dir, row)

            success_count += 1
            if (i + 1) % 10 == 0:
                print(f"  Processed {i + 1}/{len(indices)} tasks...")

        except Exception as e:
            # Best-effort: a single bad row should not abort the whole run.
            print(f"  Warning: Failed to process task {idx}: {e}")
            continue

    print(f"\nDone! Prepared {success_count}/{len(indices)} tasks in {output_dir}")
    print(f"Set SWEBENCH_TASKS_ROOT={output_dir.absolute()} to use these tasks.")
233
+
234
+
235
def main():
    """CLI entry point: parse arguments and prepare the SWE-bench tasks."""
    parser = argparse.ArgumentParser(description="Prepare SWE-bench Lite tasks")
    parser.add_argument(
        "--max-tasks",
        type=int,
        default=30,
        help="Maximum number of tasks to download (default: 30)",
    )
    parser.add_argument(
        "--difficulty",
        type=str,
        default="all",
        choices=["easy", "medium", "hard", "all"],
        help="Difficulty level to download (default: all)",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default=None,
        help="Output directory (default: dataset/swebench_lite_tasks)",
    )
    args = parser.parse_args()

    # Default output directory lives next to this script.
    if args.output_dir:
        target = Path(args.output_dir)
    else:
        target = Path(__file__).parent / "swebench_lite_tasks"

    prepare_swebench_tasks(
        output_dir=target,
        max_tasks=args.max_tasks,
        difficulty=args.difficulty,
    )
270
+ )
271
+
272
+
273
+ if __name__ == "__main__":
274
+ main()
dataset/problem_1/buggy.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
def safe_divide(a: float, b: float) -> float:
    """Divide a by b; only return inf for division by zero."""
    try:
        return a / b
    except Exception:
        # BUG: catches unrelated errors too broadly.
        # NOTE: this bug is INTENTIONAL — it is the training target that the
        # accompanying test.py is written to expose (e.g. a TypeError from
        # non-numeric operands is swallowed here). Do not "fix" it in place.
        return float("inf")
dataset/problem_1/metadata.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "difficulty": "easy",
3
+ "bug_type": "exception-handling",
4
+ "expected_steps": 1
5
+ }
dataset/problem_1/test.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ from dataset.problem_1.buggy import safe_divide
3
+
4
+
5
class TestSafeDivide(unittest.TestCase):
    """Encodes the intended (fixed) contract of safe_divide."""

    def test_normal(self):
        # Ordinary division passes straight through.
        self.assertEqual(4, safe_divide(8, 2))

    def test_zero_division(self):
        # Only division by zero maps to infinity.
        self.assertEqual(float("inf"), safe_divide(1, 0))

    def test_type_error_should_raise(self):
        # Bad operand types must surface, not be masked as infinity.
        with self.assertRaises(TypeError):
            safe_divide("1", 1)


if __name__ == "__main__":
    unittest.main()
dataset/problem_10/buggy.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from dataset.problem_10.helpers import transpose
2
+
3
+
4
def rotate_90_clockwise(matrix: list[list[int]]) -> list[list[int]]:
    """Rotate matrix 90 degrees clockwise."""
    # BUG (intentional benchmark fixture): transposing and then reversing
    # the ROW ORDER is a counter-clockwise turn; a clockwise turn would
    # reverse each row of the transpose instead.
    flipped = transpose(matrix)
    return list(reversed(flipped))
dataset/problem_10/helpers.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
def transpose(matrix: list[list[int]]) -> list[list[int]]:
    """Return the transpose of a matrix as a new list of row lists.

    Note: ragged input is silently truncated to the shortest row,
    matching zip() semantics.
    """
    return [list(column) for column in zip(*matrix)]
dataset/problem_10/metadata.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "difficulty": "medium",
3
+ "bug_type": "matrix-transformation",
4
+ "expected_steps": 1
5
+ }
dataset/problem_10/test.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ from dataset.problem_10.buggy import rotate_90_clockwise
3
+
4
+
5
class TestRotateMatrix(unittest.TestCase):
    """Intended behavior: a true clockwise quarter turn."""

    def test_2x2(self):
        # Clockwise: first column (bottom-up) becomes the first row.
        self.assertEqual([[3, 1], [4, 2]], rotate_90_clockwise([[1, 2], [3, 4]]))


if __name__ == "__main__":
    unittest.main()
dataset/problem_11/buggy.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
def binary_search(nums: list[int], target: int) -> int:
    """Return index of target, or -1 if not found."""
    # BUG (intentional benchmark fixture): the loop condition excludes the
    # case lo == hi, so the final remaining candidate index is never
    # inspected and a target sitting there reports -1.
    lo = 0
    hi = len(nums) - 1
    while lo < hi:
        middle = (lo + hi) // 2
        probe = nums[middle]
        if probe == target:
            return middle
        if probe < target:
            lo = middle + 1
        else:
            hi = middle - 1
    return -1
dataset/problem_11/metadata.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "difficulty": "medium",
3
+ "bug_type": "boundary-condition",
4
+ "expected_steps": 2
5
+ }
dataset/problem_11/test.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ from dataset.problem_11.buggy import binary_search
3
+
4
+
5
class TestBinarySearch(unittest.TestCase):
    """Intended behavior: find targets anywhere, including the last slot."""

    def test_found_middle(self):
        self.assertEqual(2, binary_search([1, 3, 5, 7], 5))

    def test_found_last(self):
        # Exercises exactly the boundary the buggy loop condition misses.
        self.assertEqual(3, binary_search([1, 3, 5, 7], 7))

    def test_not_found(self):
        self.assertEqual(-1, binary_search([1, 3, 5, 7], 4))


if __name__ == "__main__":
    unittest.main()
dataset/problem_12/buggy.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
def parse_pairs(raw: str) -> dict[str, int]:
    """Parse strings like 'a=1,b=2' into a dict."""
    # BUG (intentional benchmark fixture): keys keep any surrounding
    # whitespace; only int() incidentally tolerates padded values.
    if not raw:
        return {}
    parsed: dict[str, int] = {}
    for chunk in raw.split(","):
        name, _, number = chunk.partition("=")
        parsed[name] = int(number)
    return parsed
dataset/problem_12/metadata.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "difficulty": "easy",
3
+ "bug_type": "string-normalization",
4
+ "expected_steps": 2
5
+ }
dataset/problem_12/test.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ from dataset.problem_12.buggy import parse_pairs
3
+
4
+
5
class TestParsePairs(unittest.TestCase):
    """Intended behavior: whitespace around keys/values is ignored."""

    def test_simple(self):
        self.assertEqual({"a": 1, "b": 2}, parse_pairs("a=1,b=2"))

    def test_spaces(self):
        # Padded segments must normalize to clean keys.
        self.assertEqual({"x": 10, "y": 20}, parse_pairs("x = 10, y = 20"))


if __name__ == "__main__":
    unittest.main()
dataset/problem_13/buggy.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataset.problem_13.cache import LRUCache
2
+
3
+
4
def run_ops() -> tuple[int, int]:
    """Exercise a capacity-2 LRUCache and report lookups for 'a' and 'b'."""
    lru = LRUCache(2)
    for key, value in (("a", 1), ("b", 2)):
        lru.put(key, value)
    lru.get("a")  # with a correct LRU this read refreshes 'a'
    lru.put("c", 3)  # forces eviction of the least-recently-used entry
    return lru.get("a"), lru.get("b")
dataset/problem_13/cache.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+
3
+
4
class LRUCache:
    """Bounded key/value store evicting the least-recently-used entry.

    BUG (intentional benchmark fixture): get() never refreshes recency,
    so reads do not protect an entry from eviction.
    """

    def __init__(self, capacity: int):
        self.capacity = capacity
        self.store: OrderedDict[str, int] = OrderedDict()

    def get(self, key: str) -> int:
        """Return the stored value, or -1 when key is absent."""
        # BUG: a correct LRU would call self.store.move_to_end(key) here.
        value = self.store.get(key)
        return -1 if value is None else value

    def put(self, key: str, value: int) -> None:
        """Insert or refresh key, evicting the oldest entry past capacity."""
        if key in self.store:
            del self.store[key]
        self.store[key] = value
        while len(self.store) > self.capacity:
            # popitem(last=False) removes the oldest (least recent) entry.
            self.store.popitem(last=False)
dataset/problem_13/metadata.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "difficulty": "hard",
3
+ "bug_type": "state-logic",
4
+ "expected_steps": 2
5
+ }
dataset/problem_13/test.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ from dataset.problem_13.buggy import run_ops
3
+
4
+
5
class TestLRU(unittest.TestCase):
    """Intended behavior: get() refreshes recency, so 'b' gets evicted."""

    def test_recency_update_on_get(self):
        value_a, value_b = run_ops()
        self.assertEqual(1, value_a)
        self.assertEqual(-1, value_b)


if __name__ == "__main__":
    unittest.main()
dataset/problem_14/buggy.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
def fibonacci_recursive(n: int) -> int:
    """Return nth Fibonacci number."""
    # BUG (intentional benchmark fixture): n == 0 should yield 0, but the
    # shared base case returns 1, shifting the whole sequence by one.
    if n <= 1:
        return 1
    return fibonacci_recursive(n - 2) + fibonacci_recursive(n - 1)
dataset/problem_14/metadata.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "difficulty": "easy",
3
+ "bug_type": "recursion-base-case",
4
+ "expected_steps": 2
5
+ }
dataset/problem_14/test.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ from dataset.problem_14.buggy import fibonacci_recursive
3
+
4
+
5
class TestFibonacciRecursive(unittest.TestCase):
    """Intended behavior: the standard 0-indexed Fibonacci sequence."""

    def test_base_cases(self):
        # fib(0) must be 0 -- the planted bug returns 1 here.
        self.assertEqual(0, fibonacci_recursive(0))
        self.assertEqual(1, fibonacci_recursive(1))

    def test_n5(self):
        self.assertEqual(5, fibonacci_recursive(5))


if __name__ == "__main__":
    unittest.main()
dataset/problem_15/buggy.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
def has_overlap(a: tuple[int, int], b: tuple[int, int]) -> bool:
    """Check if closed intervals [a0, a1] and [b0, b1] overlap."""
    # BUG (intentional benchmark fixture): strict comparisons reject
    # intervals that merely touch at an endpoint, which closed intervals
    # should count as overlapping.
    a_start, a_end = a
    b_start, b_end = b
    return b_start < a_end and a_start < b_end
dataset/problem_15/metadata.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "difficulty": "medium",
3
+ "bug_type": "boundary-condition",
4
+ "expected_steps": 1
5
+ }
dataset/problem_15/test.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ from dataset.problem_15.buggy import has_overlap
3
+
4
+
5
class TestIntervalOverlap(unittest.TestCase):
    """Intended behavior: closed intervals overlap even when just touching."""

    def test_overlapping(self):
        self.assertTrue(has_overlap((1, 5), (4, 9)))

    def test_touching_endpoints(self):
        # A shared endpoint counts as overlap for closed intervals.
        self.assertTrue(has_overlap((1, 3), (3, 7)))


if __name__ == "__main__":
    unittest.main()
dataset/problem_16/buggy.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataset.problem_16.helpers import normalize_scores
2
+
3
+
4
def top_label(scores: dict[str, float]) -> str:
    """Return label with highest normalized probability."""
    labels = list(scores)
    probs = normalize_scores([scores[label] for label in labels])
    # BUG (intentional benchmark fixture): selects the index of the
    # SMALLEST probability instead of the largest.
    smallest = min(range(len(probs)), key=probs.__getitem__)
    return labels[smallest]
dataset/problem_16/helpers.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
def normalize_scores(scores: list[float]) -> list[float]:
    """Scale the scores so they sum to one.

    Note: a zero total propagates ZeroDivisionError to the caller.
    """
    denominator = sum(scores)
    return [value / denominator for value in scores]
dataset/problem_16/metadata.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "difficulty": "easy",
3
+ "bug_type": "logic-error",
4
+ "expected_steps": 1
5
+ }
dataset/problem_16/test.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ from dataset.problem_16.buggy import top_label
3
+
4
+
5
class TestTopLabel(unittest.TestCase):
    """Intended behavior: the label with the LARGEST probability wins."""

    def test_select_highest(self):
        self.assertEqual("dog", top_label({"cat": 0.2, "dog": 0.7, "bird": 0.1}))


if __name__ == "__main__":
    unittest.main()
dataset/problem_17/buggy.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
def dedupe_preserve_order(items: list[int]) -> list[int]:
    """Remove duplicates while preserving first occurrence order."""
    # BUG (intentional benchmark fixture): on a repeat, the earlier copy
    # is dropped and the value re-appended, so the LAST occurrence wins
    # instead of the first.
    encountered: set[int] = set()
    result: list[int] = []
    for value in items:
        if value in encountered:
            result = [kept for kept in result if kept != value]
        else:
            encountered.add(value)
        result.append(value)
    return result
dataset/problem_17/metadata.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "difficulty": "medium",
3
+ "bug_type": "logic-error",
4
+ "expected_steps": 2
5
+ }
dataset/problem_17/test.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ from dataset.problem_17.buggy import dedupe_preserve_order
3
+
4
+
5
class TestDedupe(unittest.TestCase):
    """Intended behavior: keep the FIRST occurrence of each value."""

    def test_order(self):
        self.assertEqual([1, 2, 3], dedupe_preserve_order([1, 2, 1, 3, 2]))


if __name__ == "__main__":
    unittest.main()
dataset/problem_18/buggy.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataset.problem_18.math_utils import clamp
2
+
3
+
4
def moving_average(nums: list[int], window: int) -> list[float]:
    """Simple moving average over a fixed window."""
    if window <= 0:
        raise ValueError("window must be positive")

    # Oversized windows are clamped down to the data length.
    window = clamp(window, 1, len(nums))
    # BUG (intentional benchmark fixture): the range stops one start
    # position early, dropping the final valid window.
    return [
        sum(nums[start : start + window]) / window
        for start in range(len(nums) - window)
    ]
dataset/problem_18/math_utils.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
def clamp(value: int, low: int, high: int) -> int:
    """Constrain value to the closed range [low, high].

    Assumes low <= high; callers in this package pass 1..len(data).
    """
    if value < low:
        return low
    return high if value > high else value
dataset/problem_18/metadata.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "difficulty": "medium",
3
+ "bug_type": "off-by-one",
4
+ "expected_steps": 1
5
+ }
dataset/problem_18/test.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ from dataset.problem_18.buggy import moving_average
3
+
4
+
5
class TestMovingAverage(unittest.TestCase):
    """Intended behavior: include the final window; clamp oversized windows."""

    def test_window_3(self):
        self.assertEqual([2.0, 3.0, 4.0], moving_average([1, 2, 3, 4, 5], 3))

    def test_window_larger_than_data(self):
        # A window wider than the data clamps to one full-span average.
        self.assertEqual([3.0], moving_average([2, 4], 5))


if __name__ == "__main__":
    unittest.main()
dataset/problem_19/buggy.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
def calculate_employee_bonus(employees: list[dict], metrics: dict) -> list[dict]:
    """
    Calculate employee bonuses based on their base salary, performance rating,
    and company-wide metrics.

    employees: list of dicts with 'id', 'role', 'base_salary', 'rating' (1-5)
    metrics: dict with 'company_multiplier' and 'department_multipliers'

    Returns a list of dicts with 'id' and 'bonus'.

    NOTE(review): this module is a benchmark fixture; the BUG markers below
    describe defects planted on purpose for agents to find and fix. The
    companion test.py encodes the intended (fixed) behavior.
    """
    results = []

    for emp in employees:
        # BUG 1: no input sanitization -- a string 'base_salary' (e.g. '80000')
        # raises TypeError further down when multiplied by a float, and a
        # missing salary silently defaults to 0.
        base = emp.get('base_salary', 0)
        rating = emp.get('rating', 1)

        # BUG 2: direct subscript on the department map raises KeyError for a
        # role with no configured multiplier instead of falling back to 1.0.
        role_mult = metrics.get('department_multipliers', {})[emp.get('role')] # will raise KeyError if role not found

        # Tiered base bonus: 10% above rating 3, 5% at exactly 3, else nothing.
        if rating > 3:
            base_bonus = base * 0.1
        elif rating == 3:
            base_bonus = base * 0.05
        else:
            base_bonus = 0

        # BUG 3: the company multiplier is ADDED to the total instead of
        # multiplying it.
        total_bonus = base_bonus * role_mult + metrics.get('company_multiplier', 1)

        # BUG 4: writes 'bonus' into the caller's dict and appends that same
        # dict, mutating the input and leaking all original keys into the result.
        emp['bonus'] = total_bonus
        results.append(emp)

    return results
dataset/problem_19/metadata.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "difficulty": "hard",
3
+ "bug_type": "multiple",
4
+ "expected_steps": 4
5
+ }
dataset/problem_19/test.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ from dataset.problem_19.buggy import calculate_employee_bonus
3
+
4
def test_calculate_employee_bonus():
    """Pin the intended (fixed) behavior: coercion, fallbacks, and purity."""
    employees = [
        {'id': 1, 'role': 'engineering', 'base_salary': 100000, 'rating': 4},
        {'id': 2, 'role': 'sales', 'base_salary': '80000', 'rating': 3},
        {'id': 3, 'role': 'hr', 'base_salary': 60000, 'rating': 2},
        {'id': 4, 'role': 'unknown', 'base_salary': 50000, 'rating': 5}
    ]
    metrics = {
        'company_multiplier': 1.2,
        'department_multipliers': {
            'engineering': 1.5,
            'sales': 1.2,
            'hr': 1.0
        }
    }

    # Snapshot the inputs so mutation can be detected afterwards.
    snapshot = [dict(emp) for emp in employees]

    results = calculate_employee_bonus(employees, metrics)

    # The function must be pure with respect to its inputs.
    assert employees == snapshot, "Original list was mutated"

    # Each result carries exactly the 'id' and 'bonus' fields.
    assert len(results) == 4
    for entry in results:
        assert 'id' in entry
        assert 'bonus' in entry
        assert 'role' not in entry  # Should only contain id and bonus

    # Expected bonuses, in input order:
    #   id 1: 100000 * 0.10 * 1.5 * 1.2 = 18000   (rating > 3)
    #   id 2:  80000 * 0.05 * 1.2 * 1.2 = 5760    (string salary coerced)
    #   id 3:      0                              (rating below 3)
    #   id 4:  50000 * 0.10 * 1.0 * 1.2 = 6000    (unknown role -> 1.0)
    expected_bonuses = [18000, 5760, 0, 6000]
    for position, bonus in enumerate(expected_bonuses):
        assert results[position]['bonus'] == bonus
dataset/problem_2/buggy.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def binary_search(nums: list[int], target: int) -> int:
2
+ """Return index of target, or -1 if not found."""
3
+ left, right = 0, len(nums) - 1
4
+
5
+ while left < right:
6
+ mid = (left + right) // 2
7
+ if nums[mid] == target:
8
+ return mid
9
+ if nums[mid] < target:
10
+ left = mid + 1
11
+ else:
12
+ right = mid - 1
13
+
14
+ return -1
dataset/problem_2/metadata.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "difficulty": "medium",
3
+ "bug_type": "boundary-condition",
4
+ "expected_steps": 2
5
+ }