PRANAV05092003 committed on
Commit
8d66fec
·
1 Parent(s): 900e1f4

Updated structure and fixed module import issue

ACRE_FINAL/.gitignore DELETED
@@ -1,26 +0,0 @@
1
- __pycache__/
2
- *.pyc
3
- *.pyo
4
- *.pyd
5
- .Python
6
- *.egg-info/
7
- dist/
8
- build/
9
- .env
10
- .venv
11
- venv/
12
- *.zip
13
- acre_agent.zip
14
- *.log
15
- .DS_Store
16
- .deps/
17
- libs/
18
- numpy.libs/
19
- *.dll
20
- *.so
21
- *.dylib
22
- env/
23
- ENV/
24
- .cache/
25
- .huggingface/
26
- Thumbs.db
 
ACRE_FINAL/Dockerfile DELETED
@@ -1,23 +0,0 @@
1
- FROM python:3.11-slim
2
-
3
- WORKDIR /app
4
-
5
- RUN apt-get update && apt-get install -y --no-install-recommends \
6
- build-essential \
7
- && rm -rf /var/lib/apt/lists/*
8
-
9
- COPY requirements.txt .
10
- RUN pip install --no-cache-dir -r requirements.txt
11
-
12
- COPY . .
13
-
14
- ENV API_BASE_URL=https://api.openai.com/v1
15
- ENV MODEL_NAME=gpt-4o-mini
16
- ENV PORT=7860
17
-
18
- EXPOSE 7860
19
-
20
- HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
21
- CMD python -c "import requests; requests.get('http://localhost:7860/').raise_for_status()"
22
-
23
- CMD ["python", "server.py"]
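Note on the `HEALTHCHECK` above: it imports `requests`, so it only works if that package is listed in `requirements.txt`. A minimal sketch of an equivalent probe using only the standard library follows. The function name `is_healthy` is hypothetical and does not appear in the repository; the URL and 10-second timeout mirror the values in the Dockerfile.

```python
import urllib.request


def is_healthy(url: str, timeout: float = 10.0) -> bool:
    """Return True if a GET on `url` answers with a 2xx status.

    Hypothetical helper (not part of the repository). urlopen raises
    HTTPError for 4xx/5xx responses and URLError for connection failures;
    both are OSError subclasses, so a single except clause covers them.
    """
    try:
        with urllib.request.urlopen(url, timeout=timeout) as response:
            return 200 <= response.status < 300
    except OSError:
        return False


if __name__ == "__main__":
    # Exit code 0 signals "healthy" to Docker, 1 signals "unhealthy".
    raise SystemExit(0 if is_healthy("http://localhost:7860/") else 1)
```

Saved as a script, this could be invoked from the health check with `CMD python <script>`, which removes the dependency on a third-party HTTP client inside the probe.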
 
ACRE_FINAL/README.md DELETED
@@ -1,174 +0,0 @@
1
- ---
2
- title: ACRE - Autonomous Code Refactoring Environment
3
- colorFrom: blue
4
- colorTo: green
5
- sdk: docker
6
- app_port: 7860
7
- pinned: false
8
- license: mit
9
- tags:
10
- - openenv
11
- ---
12
-
13
- # ACRE - Autonomous Code Refactoring Environment
14
-
15
- ACRE is an OpenEnv-compatible environment for autonomous Python code refactoring. An agent receives real code-cleanup tasks and must improve the code through AST-based transformations while receiving dense reward feedback for correctness, simplification, and performance.
16
-
17
- ## Environment Overview and Motivation
18
-
19
- This project simulates a realistic developer workflow: cleaning up messy Python code, removing dead logic, simplifying loops, and inlining trivial helpers. The canonical OpenEnv wrapper lives in `openenv_interface.py`, while the original Gymnasium-compatible environment remains available for RL training and demos.
20
-
21
- ## Definitions of Action and Observation Spaces
22
-
23
- ### Action Space - Discrete(5)
24
-
25
- | Action | Name | Description |
26
- |---|---|---|
27
- | 0 | rename_variable | Rename generic variables like `x`, `tmp`, and `i` |
28
- | 1 | remove_dead_code | Remove unreachable statements, `if False` branches, and unused assignments |
29
- | 2 | simplify_loop | Convert append-loops into list comprehensions |
30
- | 3 | optimize_condition | Simplify `not not x`, `if True`, `if False`, and boolean comparisons |
31
- | 4 | inline_function | Inline simple single-return module-level functions |
32
-
33
- ### Observation Space - Box(4,)
34
-
35
- The environment tracks:
36
-
37
- - `code_length`
38
- - `complexity_score`
39
- - `runtime_s`
40
- - `error_flag`
41
-
42
- ### Typed OpenEnv Models
43
-
44
- The submission-facing interface uses Pydantic models in `models.py`:
45
-
46
- - `ObservationModel`
47
- - `ActionModel`
48
- - `RewardModel`
49
- - `StateResponse`
50
-
51
- The canonical interface is:
52
-
53
- ```python
54
- observation = env.reset(...)
55
- observation, reward, done, info = env.step(action)
56
- state = env.state()
57
- ```
58
-
59
- ## Task Descriptions with Expected Difficulty Levels
60
-
61
- | Task ID | Difficulty | Objective |
62
- |---|---|---|
63
- | `rename_variables` | Easy | Remove generic variable names from the snippet |
64
- | `remove_dead_code` | Medium | Eliminate dead branches, unreachable code, and unused assignments |
65
- | `full_refactor` | Hard | Combine renaming, dead-code removal, loop simplification, condition optimization, and inlining |
66
-
67
- Each task includes a deterministic AST-based grader returning a score in `[0.0, 1.0]`.
68
-
69
- ## Reward Design
70
-
71
- Rewards are shaped throughout the trajectory instead of only at the end.
72
-
73
- - Success reward for syntactically valid, executable output
74
- - Complexity reward when control-flow complexity decreases
75
- - Performance reward when runtime improves
76
- - Error penalty for invalid or failing code
77
- - No-change penalty to discourage loops and unproductive actions
78
-
79
- Raw reward range is `[-32, 20]`, normalized to `[0.0, 1.0]` with `(raw + 32) / 52`.
80
-
81
- ## HTTP API
82
-
83
- | Method | Path | Purpose |
84
- |---|---|---|
85
- | GET | `/` | Health check |
86
- | GET | `/health` | Compatibility health check |
87
- | POST | `/reset` | Reset environment and return typed observation/state |
88
- | POST | `/step` | Apply one action and return typed observation/reward/done |
89
- | GET | `/state` | Return the current typed state |
90
- | GET | `/tasks` | List available tasks |
91
- | POST | `/tasks/{task_id}/grade` | Grade submitted code |
92
-
93
- ## Setup and Usage Instructions
94
-
95
- ### Local setup
96
-
97
- ```bash
98
- pip install -r requirements.txt
99
- python server.py
100
- ```
101
-
102
- ### Baseline inference
103
-
104
- Set environment variables before running:
105
-
106
- ```bash
107
- export API_BASE_URL=https://api.openai.com/v1
108
- export MODEL_NAME=gpt-4o-mini
109
- export HF_TOKEN=your_key
110
- export ENV_URL=http://localhost:7860
111
- python inference.py
112
- ```
113
-
114
- Notes:
115
-
116
- - `API_BASE_URL` and `MODEL_NAME` have defaults in `inference.py`
117
- - `HF_TOKEN` is optional because the script falls back to a deterministic heuristic baseline
118
- - `LOCAL_IMAGE_NAME` is read for evaluator compatibility when using a local Docker image launcher
119
-
120
- ### Docker / Hugging Face Spaces
121
-
122
- ```bash
123
- docker build -t acre .
124
- docker run -p 7860:7860 \
125
- -e API_BASE_URL=https://api.openai.com/v1 \
126
- -e MODEL_NAME=gpt-4o-mini \
127
- -e HF_TOKEN=your_key \
128
- -e ENV_URL=http://localhost:7860 \
129
- acre
130
- ```
131
-
132
- The repository is configured for a Docker-based Hugging Face Space and includes the `openenv` tag in the front matter.
133
-
134
- ## Validation
135
-
136
- Run the repository validator:
137
-
138
- ```bash
139
- python validate.py --url http://localhost:7860
140
- ```
141
-
142
- When using the official hackathon tooling, also run:
143
-
144
- ```bash
145
- openenv validate
146
- ```
147
-
148
- ## Interactive Demo
149
-
150
- Start the server and open:
151
-
152
- ```text
153
- http://localhost:7860/demo
154
- ```
155
-
156
- The demo shows:
157
-
158
- - Original code
159
- - Optimized code
160
- - Unified diff
161
- - Per-step action and reward logs
162
-
163
- ## Baseline Performance Scores
164
-
165
- The deterministic fallback policy used by `inference.py` produces the following reproducible task scores:
166
-
167
- | Task | Score |
168
- |---|---|
169
- | `rename_variables` | 1.0 |
170
- | `remove_dead_code` | 1.0 |
171
- | `full_refactor` | 1.0 |
172
- | Average | 1.0 |
173
-
174
- These scores come from the built-in heuristic policy with `HF_TOKEN` unset, which keeps the baseline reproducible across runs.
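The README's reward normalization, `(raw + 32) / 52`, can be sketched as a small function. This is an illustration only: the names `RAW_MIN`, `RAW_MAX`, and `normalize_reward` are hypothetical, and clamping out-of-range inputs is an assumption of this sketch rather than documented behavior.

```python
RAW_MIN = -32.0  # lower bound of the raw reward range stated in the README
RAW_MAX = 20.0   # upper bound of the raw reward range stated in the README


def normalize_reward(raw: float) -> float:
    """Map a raw reward in [-32, 20] onto [0.0, 1.0] via (raw + 32) / 52."""
    scaled = (raw - RAW_MIN) / (RAW_MAX - RAW_MIN)
    # Assumption: values outside the stated raw range are clamped.
    return min(1.0, max(0.0, scaled))
```

With these bounds, a raw reward of `-6` sits exactly halfway through the range and maps to `0.5`.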
 
ACRE_FINAL/acre/__init__.py DELETED
@@ -1,14 +0,0 @@
1
- """
2
- ACRE (Autonomous Code Refactoring Environment).
3
-
4
- Package skeleton for an RL-based code refactoring system.
5
- """
6
-
7
- __all__ = [
8
- "env",
9
- "actions",
10
- "datasets",
11
- "training",
12
- "utils",
13
- ]
14
-
 
ACRE_FINAL/acre/actions/__init__.py DELETED
@@ -1,6 +0,0 @@
1
- """Action definitions and transformations for ACRE."""
2
-
3
- from .transformations import Transformation, TransformationResult
4
-
5
- __all__ = ["Transformation", "TransformationResult"]
6
-
 
ACRE_FINAL/acre/actions/transformations.py DELETED
@@ -1,518 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import ast
4
- import copy
5
- from dataclasses import dataclass
6
- from itertools import zip_longest
7
- from typing import Any, Dict, Protocol
8
-
9
-
10
- @dataclass(frozen=True)
11
- class TransformationResult:
12
- """Output of applying a transformation (placeholder)."""
13
-
14
- code: str
15
- changed: bool
16
- metadata: Dict[str, Any]
17
-
18
-
19
- class Transformation(Protocol):
20
- """Protocol for a code transformation."""
21
-
22
- name: str
23
-
24
- def apply(self, code: str) -> TransformationResult: ...
25
-
26
-
27
- def noop_transformation(code: str) -> TransformationResult:
28
- """Baseline transformation that leaves code unchanged."""
29
- return TransformationResult(code=code, changed=False, metadata={"kind": "noop"})
30
-
31
-
32
- def _finalize_result(*, original: str, out: str, meta: Dict[str, Any]) -> TransformationResult:
33
- """
34
- Standardize metadata across transformations.
35
-
36
- - Adds `lines_changed` and `impact` for explainability/metrics.
37
- - Ensures formatting-only changes don't count as `changed`.
38
- """
39
-
40
- def _count_lines_changed(a: str, b: str) -> int:
41
- a_lines = a.splitlines()
42
- b_lines = b.splitlines()
43
- changed = 0
44
- for x, y in zip_longest(a_lines, b_lines, fillvalue=None):
45
- if x != y:
46
- changed += 1
47
- return int(changed)
48
-
49
- lines_changed = _count_lines_changed(original, out)
50
-
51
- # Fallback identity check: AST round-trips can reformat without changing meaning.
52
- # If the textual content is the same after stripping, treat it as unchanged.
53
- if out.strip() == original.strip():
54
- meta["success"] = False
55
- meta["lines_changed"] = 0
56
- meta["impact"] = "low"
57
- return TransformationResult(code=original, changed=False, metadata=meta)
58
-
59
- meta["lines_changed"] = lines_changed
60
- meta["impact"] = "high" if lines_changed >= 3 else "low"
61
- meta["success"] = True
62
- return TransformationResult(code=out, changed=True, metadata=meta)
63
-
64
-
65
- def _unchanged(*, code: str, meta: Dict[str, Any]) -> TransformationResult:
66
- meta.setdefault("success", False)
67
- meta.setdefault("lines_changed", 0)
68
- meta.setdefault("impact", "low")
69
- return TransformationResult(code=code, changed=False, metadata=meta)
70
-
71
-
72
- def rename_variable(code: str) -> TransformationResult:
73
- """
74
- Rename simple, generic variable names to more descriptive ones.
75
-
76
- Hackathon-scope heuristic:
77
- - Rename generic names in priority order: x, tmp, i.
78
- - Uses descriptive base names and avoids collisions.
79
- - Applies to Name nodes and function args.
80
- """
81
- meta: Dict[str, Any] = {"type": "rename_variable", "success": False}
82
- try:
83
- tree = ast.parse(code)
84
-
85
- class _NameCollector(ast.NodeVisitor):
86
- def __init__(self) -> None:
87
- self.names: set[str] = set()
88
-
89
- def visit_Name(self, node: ast.Name) -> None: # noqa: N802
90
- self.names.add(node.id)
91
-
92
- def visit_arg(self, node: ast.arg) -> None: # noqa: N802
93
- self.names.add(node.arg)
94
-
95
- collector = _NameCollector()
96
- collector.visit(tree)
97
-
98
- rename_plan = [
99
- ("x", "value"),
100
- ("tmp", "temp_value"),
101
- ("i", "index"),
102
- ]
103
-
104
- old = ""
105
- base_new = "value"
106
- for candidate_old, candidate_base in rename_plan:
107
- if candidate_old in collector.names:
108
- old = candidate_old
109
- base_new = candidate_base
110
- break
111
-
112
- if not old:
113
- return _unchanged(code=code, meta=meta)
114
-
115
- new = base_new
116
- i = 1
117
- while new in collector.names:
118
- new = f"{base_new}{i}"
119
- i += 1
120
-
121
- class _Renamer(ast.NodeTransformer):
122
- def __init__(self, old_name: str, new_name: str) -> None:
123
- self.old_name = old_name
124
- self.new_name = new_name
125
- self.changed = False
126
-
127
- def visit_Name(self, node: ast.Name) -> ast.AST: # noqa: N802
128
- if node.id == self.old_name:
129
- self.changed = True
130
- return ast.copy_location(ast.Name(id=self.new_name, ctx=node.ctx), node)
131
- return node
132
-
133
- def visit_arg(self, node: ast.arg) -> ast.AST: # noqa: N802
134
- if node.arg == self.old_name:
135
- self.changed = True
136
- new_node = copy.copy(node)
137
- new_node.arg = self.new_name
138
- return new_node
139
- return node
140
-
141
- renamer = _Renamer(old, new)
142
- tree = renamer.visit(tree)
143
- ast.fix_missing_locations(tree)
144
-
145
- if not renamer.changed:
146
- return _unchanged(code=code, meta=meta)
147
-
148
- out = ast.unparse(tree)
149
- meta["old"] = old
150
- meta["new"] = new
151
- # Renames tend to be small diffs; label as low impact unless the diff is large.
152
- return _finalize_result(original=code, out=out, meta=meta)
153
- except Exception:
154
- return _unchanged(code=code, meta=meta)
155
-
156
-
157
- def remove_dead_code(code: str) -> TransformationResult:
158
- """
159
- Remove simple dead code patterns.
160
-
161
- Hackathon-scope heuristics:
162
- - Drop statements after `return` / `raise` in the same block.
163
- - Remove `if False: ...` blocks (keep `else` if present).
164
- - Remove assignments to unused names in a block (very simple check).
165
- """
166
- meta: Dict[str, Any] = {"type": "remove_dead_code", "success": False}
167
-
168
- try:
169
- tree = ast.parse(code)
170
-
171
- def _is_const_bool(expr: ast.AST, value: bool) -> bool:
172
- return isinstance(expr, ast.Constant) and isinstance(expr.value, bool) and expr.value is value
173
-
174
- class _LoadNameCollector(ast.NodeVisitor):
175
- def __init__(self) -> None:
176
- self.loaded: set[str] = set()
177
-
178
- def visit_Name(self, node: ast.Name) -> None: # noqa: N802
179
- if isinstance(node.ctx, ast.Load):
180
- self.loaded.add(node.id)
181
-
182
- class _DeadCode(ast.NodeTransformer):
183
- def __init__(self) -> None:
184
- self.changed = False
185
-
186
- def _prune_unreachable(self, stmts: list[ast.stmt]) -> list[ast.stmt]:
187
- out: list[ast.stmt] = []
188
- unreachable = False
189
- for s in stmts:
190
- if unreachable:
191
- self.changed = True
192
- continue
193
- out.append(s)
194
- if isinstance(s, (ast.Return, ast.Raise)):
195
- unreachable = True
196
- return out
197
-
198
- def _remove_unused_assigns(self, stmts: list[ast.stmt]) -> list[ast.stmt]:
199
- collector = _LoadNameCollector()
200
- for s in stmts:
201
- collector.visit(s)
202
- used = collector.loaded
203
-
204
- out: list[ast.stmt] = []
205
- for s in stmts:
206
- if isinstance(s, ast.Assign) and all(isinstance(t, ast.Name) for t in s.targets):
207
- targets = [t.id for t in s.targets if isinstance(t, ast.Name)]
208
- # Remove only if *all* assigned names are unused.
209
- if targets and all(t not in used for t in targets):
210
- self.changed = True
211
- continue
212
- if isinstance(s, ast.AnnAssign) and isinstance(s.target, ast.Name):
213
- if s.target.id not in used:
214
- self.changed = True
215
- continue
216
- out.append(s)
217
- return out
218
-
219
- def _clean_block(self, stmts: list[ast.stmt]) -> list[ast.stmt]:
220
- # First apply transformations inside statements.
221
- visited = [self.visit(s) for s in stmts]
222
- flat: list[ast.stmt] = []
223
- for s in visited:
224
- if s is None:
225
- self.changed = True
226
- continue
227
- if isinstance(s, list):
228
- flat.extend([x for x in s if isinstance(x, ast.stmt)])
229
- self.changed = True
230
- else:
231
- flat.append(s)
232
-
233
- flat = self._prune_unreachable(flat)
234
- flat = self._remove_unused_assigns(flat)
235
- return flat
236
-
237
- def visit_Module(self, node: ast.Module) -> ast.AST: # noqa: N802
238
- node.body = self._clean_block(node.body)
239
- return node
240
-
241
- def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST: # noqa: N802
242
- node.body = self._clean_block(node.body)
243
- return node
244
-
245
- def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AST: # noqa: N802
246
- node.body = self._clean_block(node.body)
247
- return node
248
-
249
- def visit_If(self, node: ast.If) -> ast.AST | list[ast.stmt]: # noqa: N802
250
- node = self.generic_visit(node)
251
- if _is_const_bool(node.test, False):
252
- self.changed = True
253
- return node.orelse or []
254
- return node
255
-
256
- def visit_While(self, node: ast.While) -> ast.AST | None: # noqa: N802
257
- node = self.generic_visit(node)
258
- if _is_const_bool(node.test, False):
259
- self.changed = True
260
- return None
261
- return node
262
-
263
- dc = _DeadCode()
264
- tree = dc.visit(tree)
265
- ast.fix_missing_locations(tree)
266
- if not dc.changed:
267
- return _unchanged(code=code, meta=meta)
268
-
269
- out = ast.unparse(tree)
270
- return _finalize_result(original=code, out=out, meta=meta)
271
- except Exception:
272
- return _unchanged(code=code, meta=meta)
273
-
274
-
275
- def simplify_loops(code: str) -> TransformationResult:
276
- """
277
- Simplify very basic loop patterns into more pythonic forms.
278
-
279
- Supported pattern (only when adjacent in the same block):
280
- - xs = []
281
- for t in it:
282
- xs.append(expr)
283
- => xs = [expr for t in it]
284
- """
285
- meta: Dict[str, Any] = {"type": "simplify_loops", "success": False}
286
- try:
287
- tree = ast.parse(code)
288
-
289
- class _LoopSimplifier(ast.NodeTransformer):
290
- def __init__(self) -> None:
291
- self.changed = False
292
-
293
- def _simplify_body(self, body: list[ast.stmt]) -> list[ast.stmt]:
294
- out: list[ast.stmt] = []
295
- i = 0
296
- while i < len(body):
297
- cur = body[i]
298
- nxt = body[i + 1] if i + 1 < len(body) else None
299
-
300
- if (
301
- isinstance(cur, ast.Assign)
302
- and len(cur.targets) == 1
303
- and isinstance(cur.targets[0], ast.Name)
304
- and isinstance(cur.value, ast.List)
305
- and cur.value.elts == []
306
- and isinstance(nxt, ast.For)
307
- and len(nxt.body) == 1
308
- and isinstance(nxt.body[0], ast.Expr)
309
- and isinstance(nxt.body[0].value, ast.Call)
310
- ):
311
- list_name = cur.targets[0].id
312
- call = nxt.body[0].value
313
- if (
314
- isinstance(call.func, ast.Attribute)
315
- and isinstance(call.func.value, ast.Name)
316
- and call.func.value.id == list_name
317
- and call.func.attr == "append"
318
- and len(call.args) == 1
319
- and not call.keywords
320
- ):
321
- # Build list comprehension: [call.args[0] for <target> in <iter>]
322
- comp = ast.ListComp(
323
- elt=call.args[0],
324
- generators=[
325
- ast.comprehension(
326
- target=nxt.target,
327
- iter=nxt.iter,
328
- ifs=[],
329
- is_async=0,
330
- )
331
- ],
332
- )
333
- new_assign = ast.Assign(targets=[ast.Name(id=list_name, ctx=ast.Store())], value=comp)
334
- out.append(ast.copy_location(new_assign, cur))
335
- self.changed = True
336
- i += 2
337
- continue
338
-
339
- out.append(cur)
340
- i += 1
341
-
342
- return out
343
-
344
- def visit_Module(self, node: ast.Module) -> ast.AST: # noqa: N802
345
- node = self.generic_visit(node)
346
- node.body = self._simplify_body(node.body)
347
- return node
348
-
349
- def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST: # noqa: N802
350
- node = self.generic_visit(node)
351
- node.body = self._simplify_body(node.body)
352
- return node
353
-
354
- def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AST: # noqa: N802
355
- node = self.generic_visit(node)
356
- node.body = self._simplify_body(node.body)
357
- return node
358
-
359
- simp = _LoopSimplifier()
360
- tree = simp.visit(tree)
361
- ast.fix_missing_locations(tree)
362
- if not simp.changed:
363
- return _unchanged(code=code, meta=meta)
364
-
365
- out = ast.unparse(tree)
366
- return _finalize_result(original=code, out=out, meta=meta)
367
- except Exception:
368
- return _unchanged(code=code, meta=meta)
369
-
370
-
371
- def simplify_loop(code: str) -> TransformationResult:
372
- # Backwards-compatible alias for the environment's action mapping.
373
- return simplify_loops(code)
374
-
375
-
376
- def optimize_condition(code: str) -> TransformationResult:
377
- """
378
- Simplify redundant boolean conditions.
379
-
380
- Hackathon-scope heuristics:
381
- - Replace `if True:` with its body; `if False:` with `else` (if present).
382
- - Simplify `not not X` -> `X`.
383
- - Simplify comparisons to True/False: `X == True` -> `X`, `X == False` -> `not X`.
384
- """
385
- meta: Dict[str, Any] = {"type": "optimize_condition", "success": False}
386
- try:
387
- tree = ast.parse(code)
388
-
389
- def _is_bool_const(node: ast.AST, value: bool) -> bool:
390
- return isinstance(node, ast.Constant) and isinstance(node.value, bool) and node.value is value
391
-
392
- class _CondOpt(ast.NodeTransformer):
393
- def __init__(self) -> None:
394
- self.changed = False
395
-
396
- def visit_UnaryOp(self, node: ast.UnaryOp) -> ast.AST: # noqa: N802
397
- node = self.generic_visit(node)
398
- if isinstance(node.op, ast.Not) and isinstance(node.operand, ast.UnaryOp) and isinstance(node.operand.op, ast.Not):
399
- self.changed = True
400
- return node.operand.operand
401
- return node
402
-
403
- def visit_Compare(self, node: ast.Compare) -> ast.AST: # noqa: N802
404
- node = self.generic_visit(node)
405
- if len(node.ops) == 1 and len(node.comparators) == 1:
406
- op = node.ops[0]
407
- rhs = node.comparators[0]
408
- if isinstance(op, (ast.Eq, ast.Is)) and _is_bool_const(rhs, True):
409
- self.changed = True
410
- return node.left
411
- if isinstance(op, (ast.Eq, ast.Is)) and _is_bool_const(rhs, False):
412
- self.changed = True
413
- return ast.UnaryOp(op=ast.Not(), operand=node.left)
414
- return node
415
-
416
- def visit_If(self, node: ast.If) -> ast.AST | list[ast.stmt]: # noqa: N802
417
- node = self.generic_visit(node)
418
- if _is_bool_const(node.test, True):
419
- self.changed = True
420
- return node.body
421
- if _is_bool_const(node.test, False):
422
- self.changed = True
423
- return node.orelse or []
424
- return node
425
-
426
- opt = _CondOpt()
427
- tree = opt.visit(tree)
428
- ast.fix_missing_locations(tree)
429
- if not opt.changed:
430
- return _unchanged(code=code, meta=meta)
431
-
432
- out = ast.unparse(tree)
433
- return _finalize_result(original=code, out=out, meta=meta)
434
- except Exception:
435
- return _unchanged(code=code, meta=meta)
436
-
437
-
438
- def inline_function(code: str) -> TransformationResult:
439
- """
440
- Inline very simple functions into their call sites.
441
-
442
- Supported pattern:
443
- - def f(a, b): return <expr using only a,b>
444
- - Replace calls: f(x, y) -> <expr with a->x, b->y>
445
- Only handles module-level functions and positional args.
446
- """
447
- meta: Dict[str, Any] = {"type": "inline_function", "success": False}
448
- try:
449
- tree = ast.parse(code)
450
-
451
- simple_fns: Dict[str, tuple[list[str], ast.AST]] = {}
452
- for node in tree.body:
453
- if not isinstance(node, ast.FunctionDef):
454
- continue
455
- if node.decorator_list:
456
- continue
457
- args = node.args
458
- if args.vararg or args.kwarg or args.kwonlyargs or args.defaults or args.posonlyargs:
459
- continue
460
- if len(node.body) != 1 or not isinstance(node.body[0], ast.Return) or node.body[0].value is None:
461
- continue
462
- arg_names = [a.arg for a in args.args]
463
- # Ensure the return expression only references the function's args.
464
- referenced: set[str] = set()
465
-
466
- class _Ref(ast.NodeVisitor):
467
- def visit_Name(self, n: ast.Name) -> None: # noqa: N802
468
- if isinstance(n.ctx, ast.Load):
469
- referenced.add(n.id)
470
-
471
- _Ref().visit(node.body[0].value)
472
- if not referenced.issubset(set(arg_names)):
473
- continue
474
- simple_fns[node.name] = (arg_names, node.body[0].value)
475
-
476
- if not simple_fns:
477
- return _unchanged(code=code, meta=meta)
478
-
479
- class _Substitute(ast.NodeTransformer):
480
- def __init__(self, mapping: Dict[str, ast.AST]) -> None:
481
- self.mapping = mapping
482
-
483
- def visit_Name(self, n: ast.Name) -> ast.AST: # noqa: N802
484
- if isinstance(n.ctx, ast.Load) and n.id in self.mapping:
485
- return copy.deepcopy(self.mapping[n.id])
486
- return n
487
-
488
- class _Inliner(ast.NodeTransformer):
489
- def __init__(self) -> None:
490
- self.changed = False
491
-
492
- def visit_Call(self, node: ast.Call) -> ast.AST: # noqa: N802
493
- node = self.generic_visit(node)
494
- if not isinstance(node.func, ast.Name):
495
- return node
496
- fn = simple_fns.get(node.func.id)
497
- if fn is None:
498
- return node
499
- arg_names, expr = fn
500
- if node.keywords or len(node.args) != len(arg_names):
501
- return node
502
- mapping = {name: arg for name, arg in zip(arg_names, node.args, strict=True)}
503
- new_expr = _Substitute(mapping).visit(copy.deepcopy(expr))
504
- self.changed = True
505
- return ast.copy_location(new_expr, node)
506
-
507
- inliner = _Inliner()
508
- tree = inliner.visit(tree)
509
- ast.fix_missing_locations(tree)
510
- if not inliner.changed:
511
- return _unchanged(code=code, meta=meta)
512
-
513
- out = ast.unparse(tree)
514
- meta["inlined"] = sorted(simple_fns.keys())
515
- return _finalize_result(original=code, out=out, meta=meta)
516
- except Exception:
517
- return _unchanged(code=code, meta=meta)
518
-
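For readers who want to try the append-loop rewrite from `simplify_loops` in isolation, here is a condensed, standalone sketch of the same pattern match. It is not the deleted module itself: the class name `AppendLoopToComprehension` and the helper `simplify` are hypothetical, and the extra `not nxt.orelse` guard is an addition of this sketch.

```python
import ast


class AppendLoopToComprehension(ast.NodeTransformer):
    """Rewrite `xs = []` + `for t in it: xs.append(expr)` as a list comprehension."""

    @staticmethod
    def _matches(cur, nxt) -> bool:
        # First statement must be `name = []`.
        if not (isinstance(cur, ast.Assign) and len(cur.targets) == 1
                and isinstance(cur.targets[0], ast.Name)
                and isinstance(cur.value, ast.List) and not cur.value.elts):
            return False
        # Second statement must be a for-loop whose only statement is a call.
        if not (isinstance(nxt, ast.For) and not nxt.orelse and len(nxt.body) == 1
                and isinstance(nxt.body[0], ast.Expr)
                and isinstance(nxt.body[0].value, ast.Call)):
            return False
        call = nxt.body[0].value
        # The call must be `name.append(<one positional argument>)`.
        return (isinstance(call.func, ast.Attribute) and call.func.attr == "append"
                and isinstance(call.func.value, ast.Name)
                and call.func.value.id == cur.targets[0].id
                and len(call.args) == 1 and not call.keywords)

    def _rewrite(self, body):
        out, i = [], 0
        while i < len(body):
            cur = body[i]
            nxt = body[i + 1] if i + 1 < len(body) else None
            if self._matches(cur, nxt):
                comp = ast.ListComp(
                    elt=nxt.body[0].value.args[0],
                    generators=[ast.comprehension(
                        target=nxt.target, iter=nxt.iter, ifs=[], is_async=0)],
                )
                new_assign = ast.Assign(targets=cur.targets, value=comp)
                out.append(ast.copy_location(new_assign, cur))
                i += 2  # consume both the assignment and the loop
            else:
                out.append(cur)
                i += 1
        return out

    def visit_Module(self, node):  # noqa: N802
        self.generic_visit(node)
        node.body = self._rewrite(node.body)
        return node

    def visit_FunctionDef(self, node):  # noqa: N802
        self.generic_visit(node)
        node.body = self._rewrite(node.body)
        return node


def simplify(source: str) -> str:
    """Parse, rewrite, and unparse `source` (requires Python 3.9+ for ast.unparse)."""
    tree = AppendLoopToComprehension().visit(ast.parse(source))
    ast.fix_missing_locations(tree)
    return ast.unparse(tree)
```

Calling `simplify("xs = []\nfor t in it:\n    xs.append(t * 2)\n")` returns `xs = [t * 2 for t in it]`. As in the deleted code, only module-level and function-level blocks are scanned, and neither version checks whether the appended expression reads the list being built (for example `xs.append(len(xs))`), in which case the rewrite would change behavior.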
 
ACRE_FINAL/acre/datasets/__init__.py DELETED
@@ -1,6 +0,0 @@
1
- """Datasets and sample code providers for ACRE."""
2
-
3
- from .code_samples import CodeSample, CodeSampleDataset
4
-
5
- __all__ = ["CodeSample", "CodeSampleDataset"]
6
-
 
ACRE_FINAL/acre/datasets/code_samples.py DELETED
@@ -1,34 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from dataclasses import dataclass
4
- from typing import Iterable, Iterator, List, Optional
5
-
6
-
7
- @dataclass(frozen=True)
8
- class CodeSample:
9
- """A single code sample (placeholder)."""
10
-
11
- id: str
12
- language: str
13
- code: str
14
-
15
-
16
- class CodeSampleDataset:
17
- """
18
- Minimal in-memory dataset stub.
19
-
20
- Later versions can back this with files, Git repos, or benchmark suites.
21
- """
22
-
23
- def __init__(self, samples: Optional[Iterable[CodeSample]] = None) -> None:
24
- self._samples: List[CodeSample] = list(samples or [])
25
-
26
- def __len__(self) -> int:
27
- return len(self._samples)
28
-
29
- def __iter__(self) -> Iterator[CodeSample]:
30
- return iter(self._samples)
31
-
32
- def add(self, sample: CodeSample) -> None:
33
- self._samples.append(sample)
34
-
 
ACRE_FINAL/acre/demo.py DELETED
@@ -1,185 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import os
4
- import random
5
- import sys
6
- from typing import Any, Optional, Tuple
7
-
8
- from acre.datasets.code_samples import CodeSample, CodeSampleDataset
9
- from acre.env.refactor_env import RefactorEnv
10
-
11
-
12
- def _load_model(path: str):
13
- """Load a Stable-Baselines3 PPO model if available; otherwise return None."""
14
- if not os.path.exists(path):
15
- return None
16
- try:
17
- from stable_baselines3 import PPO
18
- except Exception:
19
- return None
20
- try:
21
- return PPO.load(path)
22
- except Exception:
23
- return None
24
-
25
-
26
- def _messy_sample_code() -> str:
27
- # Intentionally "messy" but valid Python for demo purposes.
28
- return (
29
- "def add(a,b):\n"
30
- " x=0\n"
31
- " for i in range(a):\n"
32
- " x=x+1\n"
33
- " if True:\n"
34
- " x = x\n"
35
- " if False:\n"
36
- " y=123\n"
37
- " else:\n"
38
- " y=0\n"
39
- " def f(p,q):\n"
40
- " return p+q\n"
41
- " r = f(x,y)\n"
42
- " return r\n"
43
- )
44
-
45
-
46
- def _format_code_block(code: str) -> str:
47
- return "\n".join(f" {line}" for line in code.rstrip().splitlines()) + "\n"
48
-
49
-
50
- def _safe_print(text: str) -> None:
51
- """
52
- Print text safely across Windows consoles (some default encodings can't print emojis).
53
- """
54
- encoding = sys.stdout.encoding or "utf-8"
55
- try:
56
- text.encode(encoding)
57
- print(text, flush=True)
58
- except Exception:
59
- # Fall back to ASCII-friendly markers if emojis can't be encoded.
60
- safe = text.replace("✅", "[OK]").replace("⚠️", "[WARN]").replace("⚠", "[WARN]")
61
- print(safe, flush=True)
62
-
63
-
64
- def _compute_runtime(executor: Any, code: str) -> float:
65
- """Best-effort runtime metric using the current executor contract."""
66
- try:
67
- res = executor.run(code, filename="demo.py")
68
- if getattr(res, "exit_code", 1) == 0 and isinstance(getattr(res, "metrics", None), dict):
69
- return float(res.metrics.get("runtime_s", 0.0) or 0.0)
70
- except Exception:
71
- pass
72
- return 0.0
73
-
74
-
75
- def _choose_action(model: Any, obs, env: RefactorEnv, rng: random.Random) -> Tuple[int, str]:
76
- """Choose an action from the model, falling back to random."""
77
- n_actions = int(getattr(getattr(env, "action_space", None), "n", 5))
78
- if model is None:
79
- a = int(rng.randint(0, n_actions - 1))
80
- return a, "random"
81
-
82
- try:
83
- action, _state = model.predict(obs, deterministic=True)
84
- # SB3 may return scalar or 1-element array.
85
- if hasattr(action, "__len__"):
86
- a = int(action[0])
87
- else:
88
- a = int(action)
89
- return a, "ppo"
90
- except Exception:
91
- a = int(rng.randint(0, n_actions - 1))
92
- return a, "random"
93
-
94
-
95
- def run_demo(*, model_path: str = "acre_agent.zip", seed: int = 0) -> None:
96
- rng = random.Random(seed)
97
-
98
- # Create a dataset with one messy sample so `reset()` loads it deterministically.
99
- dataset = CodeSampleDataset(
100
- [
101
- CodeSample(
102
- id="demo_sample",
103
- language="python",
104
- code=_messy_sample_code(),
105
- )
106
- ]
107
- )
108
- env = RefactorEnv(dataset=dataset, seed=seed)
109
-
110
- model = _load_model(model_path)
111
- model_status = "loaded" if model is not None else "not found (using random actions)"
112
-
113
- # Reset and capture the original code/metrics.
114
- obs, info = env.reset()
115
- original_code = getattr(env, "_code", "")
116
- original_complexity = float(getattr(env, "_compute_complexity")(original_code))
117
- original_runtime = _compute_runtime(env.executor, original_code)
118
-
119
- print("=" * 72)
120
- print("ACRE: Autonomous RL Code Refactoring Agent (5-step episode)")
121
- print(f"Model: {model_path} -> {model_status}")
122
- print(f"Sample: {info.get('sample_id')} ({info.get('language')})")
123
- print("=" * 72)
124
- print("\nORIGINAL CODE:\n")
125
- print(_format_code_block(original_code))
126
-
127
- total_reward = 0.0
128
- successful_transformations = 0
129
- steps_taken = 0
130
-
131
- for step_idx in range(1, 6):
132
- action, policy = _choose_action(model, obs, env, rng)
133
- obs, reward, terminated, truncated, step_info = env.step(action)
134
- total_reward += float(reward)
135
- steps_taken = step_idx
136
-
137
- action_name = step_info.get("action_name", "unknown")
138
- transform_meta = step_info.get("transform", {})
139
- if isinstance(transform_meta, dict) and bool(transform_meta.get("success", False)):
140
- successful_transformations += 1
141
- transformed_code = getattr(env, "_code", "")
142
-
143
- print("-" * 72)
144
- print(f"STEP {step_idx}/5")
145
- print(f"policy={policy} action={action} ({action_name})")
146
- print(f"transform={transform_meta}")
147
- print(f"reward={float(reward):.2f} components={step_info.get('reward_components')}")
148
- print("\nUPDATED CODE:\n")
149
- print(_format_code_block(transformed_code))
150
-
151
- if terminated or truncated:
152
- break
153
-
154
- final_code = getattr(env, "_code", "")
155
- final_complexity = float(getattr(env, "_compute_complexity")(final_code))
156
- final_runtime = _compute_runtime(env.executor, final_code)
157
-
158
- print("=" * 72)
159
- print("FINAL SUMMARY")
160
- print("=" * 72)
161
- print(f"total_reward: {total_reward:.2f}")
162
- print(f"complexity: {original_complexity:.0f} -> {final_complexity:.0f}")
163
- print(f"runtime_s: {original_runtime:.4f} -> {final_runtime:.4f}")
164
-
165
- complexity_improvement = ((original_complexity - final_complexity) / max(original_complexity, 1.0)) * 100.0
166
- print(f"complexity improvement: {complexity_improvement:.2f}%")
167
-
168
- print("\nCHANGES APPLIED:")
169
- print(f"- Total steps: {steps_taken}")
170
- print(f"- Successful transformations: {successful_transformations}")
171
-
172
- if total_reward > 0:
173
- _safe_print("\n✅ Code improved successfully")
174
- else:
175
- _safe_print("\n⚠️ No significant improvement")
176
-
177
- print("\nFINAL CODE:\n")
178
- print(_format_code_block(final_code))
179
-
180
- env.close()
181
-
182
-
183
- if __name__ == "__main__":
184
- run_demo()
185
-
ACRE_FINAL/acre/main.py DELETED
@@ -1,39 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import argparse
4
-
5
- from acre.training.train_agent import TrainConfig, train
6
-
7
-
8
- def _build_parser() -> argparse.ArgumentParser:
9
- parser = argparse.ArgumentParser(prog="acre", description="ACRE: Autonomous Code Refactoring Environment")
10
- sub = parser.add_subparsers(dest="command", required=False)
11
-
12
- train_p = sub.add_parser("train", help="Run training (stub)")
13
- train_p.add_argument("--total-steps", type=int, default=100, help="Total training steps (stub)")
14
-
15
- sub.add_parser("demo", help="Run a small demo (stub)")
16
-
17
- return parser
18
-
19
-
20
- def run_demo() -> None:
21
- # Placeholder for a future interactive/demo flow.
22
- print("ACRE demo mode is not implemented yet.")
23
-
24
-
25
- def main(argv: list[str] | None = None) -> None:
26
- parser = _build_parser()
27
- args = parser.parse_args(argv)
28
-
29
- if args.command == "demo":
30
- run_demo()
31
- return
32
-
33
- total_steps = getattr(args, "total_steps", 100)
34
- train(config=TrainConfig(total_steps=total_steps))
35
-
36
-
37
- if __name__ == "__main__":
38
- main()
39
-
ACRE_FINAL/acre/tasks/__init__.py DELETED
@@ -1,3 +0,0 @@
1
- from acre.tasks.task_registry import Task, TaskRegistry
2
-
3
- __all__ = ["Task", "TaskRegistry"]
ACRE_FINAL/acre/tasks/task_registry.py DELETED
@@ -1,222 +0,0 @@
1
- """
2
- Three OpenEnv tasks with AST-based graders scoring 0.0-1.0.
3
- """
4
- from __future__ import annotations
5
-
6
- import ast
7
- from dataclasses import dataclass
8
- from typing import Callable, Dict, List, Optional
9
-
10
-
11
- @dataclass
12
- class Task:
13
- id: str
14
- name: str
15
- description: str
16
- difficulty: str
17
- initial_code: str
18
- _grade_fn: Callable[[str], float]
19
-
20
- def grade(self, code: str) -> float:
21
- """Return a score in [0.0, 1.0]."""
22
- try:
23
- return float(min(1.0, max(0.0, self._grade_fn(code))))
24
- except Exception:
25
- return 0.0
26
-
27
-
28
- # ---------------------------------------------------------------------------
29
- # Task 1 — Easy: Rename generic variables
30
- # ---------------------------------------------------------------------------
31
- _EASY_CODE = """\
32
- def compute(x, y, tmp):
33
- tmp = x + y
34
- x = tmp * 2
35
- result = x
36
- return result
37
- """
38
-
39
-
40
- def _grade_easy(code: str) -> float:
41
- """Score = fraction of generic names (x, tmp) removed from all scopes."""
42
- generic = {"x", "tmp"}
43
- try:
44
- tree = ast.parse(code)
45
- except SyntaxError:
46
- return 0.0
47
-
48
- remaining: set[str] = set()
49
-
50
- class _Collector(ast.NodeVisitor):
51
- def visit_Name(self, node: ast.Name) -> None:
52
- if node.id in generic:
53
- remaining.add(node.id)
54
- self.generic_visit(node)
55
-
56
- def visit_arg(self, node: ast.arg) -> None:
57
- if node.arg in generic:
58
- remaining.add(node.arg)
59
- self.generic_visit(node)
60
-
61
- _Collector().visit(tree)
62
- renamed = len(generic - remaining)
63
- return renamed / len(generic)
64
-
65
-
66
- # ---------------------------------------------------------------------------
67
- # Task 2 — Medium: Remove dead code
68
- # ---------------------------------------------------------------------------
69
- _MEDIUM_CODE = """\
70
- def process(data):
71
- result = []
72
- for item in data:
73
- result.append(item * 2)
74
- if False:
75
- print("never runs")
76
- unused_var = 42
77
- return result
78
- print("unreachable")
79
- """
80
-
81
-
82
- def _grade_medium(code: str) -> float:
83
- """Score = fraction of dead-code patterns eliminated (3 checks, ~0.33 each)."""
84
- try:
85
- tree = ast.parse(code)
86
- except SyntaxError:
87
- return 0.0
88
-
89
- source = ast.unparse(tree)
90
- score = 0.0
91
-
92
- # Check 1: if-False block removed
93
- if "if False" not in source:
94
- score += 1 / 3
95
-
96
- # Check 2: unused_var assignment removed
97
- if "unused_var" not in source:
98
- score += 1 / 3
99
-
100
- # Check 3: list comprehension used (loop simplified)
101
- has_listcomp = any(isinstance(n, ast.ListComp) for n in ast.walk(tree))
102
- if has_listcomp:
103
- score += 1 / 3
104
-
105
- return score
106
-
107
-
108
- # ---------------------------------------------------------------------------
109
- # Task 3 — Hard: Full refactor
110
- # ---------------------------------------------------------------------------
111
- _HARD_CODE = """\
112
- def add(p, q):
113
- return p + q
114
-
115
- def compute(x, data, tmp):
116
- result = []
117
- for item in data:
118
- result.append(item * 2)
119
- if False:
120
- y = 999
121
- if True:
122
- val = add(x, tmp)
123
- unused = 0
124
- flag = not not True
125
- return val
126
- print("dead")
127
- """
128
-
129
-
130
- def _grade_hard(code: str) -> float:
131
- """Score = fraction of 5 quality checks passed."""
132
- try:
133
- tree = ast.parse(code)
134
- except SyntaxError:
135
- return 0.0
136
-
137
- source = ast.unparse(tree)
138
- checks = 0
139
-
140
- # 1. No generic variable names x/tmp in function signature or body
141
- has_generic = False
142
-
143
- class _GenCheck(ast.NodeVisitor):
144
- def visit_arg(self, node: ast.arg) -> None:
145
- nonlocal has_generic
146
- if node.arg in {"x", "tmp"}:
147
- has_generic = True
148
-
149
- _GenCheck().visit(tree)
150
- if not has_generic:
151
- checks += 1
152
-
153
- # 2. No if False block
154
- if "if False" not in source:
155
- checks += 1
156
-
157
- # 3. if True removed (body inlined)
158
- if "if True" not in source:
159
- checks += 1
160
-
161
- # 4. List comprehension used
162
- if any(isinstance(n, ast.ListComp) for n in ast.walk(tree)):
163
- checks += 1
164
-
165
- # 5. add() call inlined (no call to 'add')
166
- calls = [n for n in ast.walk(tree) if isinstance(n, ast.Call)]
167
- fn_names = {c.func.id for c in calls if isinstance(c.func, ast.Name)}
168
- if "add" not in fn_names:
169
- checks += 1
170
-
171
- return checks / 5
172
-
173
-
174
- # ---------------------------------------------------------------------------
175
- # Registry
176
- # ---------------------------------------------------------------------------
177
-
178
- class TaskRegistry:
179
- def __init__(self) -> None:
180
- self._tasks: Dict[str, Task] = {}
181
- self._register_all()
182
-
183
- def _register_all(self) -> None:
184
- self._tasks["rename_variables"] = Task(
185
- id="rename_variables",
186
- name="Rename Variables (Easy)",
187
- description="Rename generic variable names (x, tmp) to descriptive ones",
188
- difficulty="easy",
189
- initial_code=_EASY_CODE,
190
- _grade_fn=_grade_easy,
191
- )
192
- self._tasks["remove_dead_code"] = Task(
193
- id="remove_dead_code",
194
- name="Remove Dead Code (Medium)",
195
- description="Remove unreachable code, if False blocks, and unused variables",
196
- difficulty="medium",
197
- initial_code=_MEDIUM_CODE,
198
- _grade_fn=_grade_medium,
199
- )
200
- self._tasks["full_refactor"] = Task(
201
- id="full_refactor",
202
- name="Full Refactor (Hard)",
203
- description="Apply all transformations: rename, dead code, loops, conditions, inlining",
204
- difficulty="hard",
205
- initial_code=_HARD_CODE,
206
- _grade_fn=_grade_hard,
207
- )
208
-
209
- def get_task(self, task_id: str) -> Optional[Task]:
210
- return self._tasks.get(task_id)
211
-
212
- def list_tasks(self) -> List[dict]:
213
- return [
214
- {
215
- "id": t.id,
216
- "name": t.name,
217
- "description": t.description,
218
- "difficulty": t.difficulty,
219
- "initial_code": t.initial_code,
220
- }
221
- for t in self._tasks.values()
222
- ]
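As a rough illustration of how the AST-based graders in the deleted `task_registry.py` arrive at a score, the sketch below re-creates the easy grader as a standalone function. The name `grade_generic_names` and the sample snippets are hypothetical, and `ast.walk` is used in place of the original `NodeVisitor` subclass; the scoring rule (fraction of the generic names `x` and `tmp` no longer present) follows the diff.

```python
import ast

# Names the easy task treats as "generic" (taken from the deleted grader).
GENERIC_NAMES = {"x", "tmp"}


def grade_generic_names(code: str) -> float:
    """Return the fraction of generic names that no longer appear in `code`."""
    try:
        tree = ast.parse(code)
    except SyntaxError:
        # Unparseable submissions score zero, as in the original grader.
        return 0.0

    remaining = set()
    for node in ast.walk(tree):
        # Variable reads/writes are ast.Name nodes; parameters are ast.arg nodes.
        if isinstance(node, ast.Name) and node.id in GENERIC_NAMES:
            remaining.add(node.id)
        elif isinstance(node, ast.arg) and node.arg in GENERIC_NAMES:
            remaining.add(node.arg)

    return len(GENERIC_NAMES - remaining) / len(GENERIC_NAMES)


# Hypothetical submissions: untouched, half renamed, fully renamed.
ORIGINAL = "def compute(x, y, tmp):\n    tmp = x + y\n    x = tmp * 2\n    return x\n"
PARTIAL = "def compute(first, y, tmp):\n    tmp = first + y\n    return tmp * 2\n"
RENAMED = "def compute(first, second):\n    total = first + second\n    return total * 2\n"

scores = [grade_generic_names(c) for c in (ORIGINAL, PARTIAL, RENAMED)]
```

Because the score only counts which generic names survive, a submission that renames one of the two names earns exactly half credit regardless of how many occurrences it had.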
ACRE_FINAL/acre/training/__init__.py DELETED
@@ -1,6 +0,0 @@
1
- """Training utilities for ACRE."""
2
-
3
- from .train_agent import TrainConfig, train
4
-
5
- __all__ = ["TrainConfig", "train"]
6
-
ACRE_FINAL/acre/training/train_agent.py DELETED
@@ -1,75 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from dataclasses import dataclass
4
- from typing import Optional
5
-
6
- from acre.env.refactor_env import RefactorEnv
7
-
8
-
9
- @dataclass(frozen=True)
10
- class TrainConfig:
11
- """Configuration stub for training."""
12
-
13
- total_steps: int = 5_000
14
- seed: Optional[int] = None
15
- model_path: str = "acre_agent.zip"
16
-
17
-
18
- def train(*, env: Optional[RefactorEnv] = None, config: Optional[TrainConfig] = None) -> None:
19
- """
20
- Train a PPO agent on `RefactorEnv` using Stable-Baselines3.
21
-
22
- This is intentionally lightweight (hackathon-friendly) and focuses on a
23
- working demo: basic training loop, simple logging, and saving the model.
24
- """
25
- _config = config or TrainConfig()
26
- _env = env or RefactorEnv(seed=_config.seed)
27
-
28
- try:
29
- from stable_baselines3 import PPO
30
- from stable_baselines3.common.callbacks import BaseCallback
31
- from stable_baselines3.common.monitor import Monitor
32
- from stable_baselines3.common.vec_env import DummyVecEnv
33
- except Exception as e: # pragma: no cover
34
- print("Stable-Baselines3 is required for training. Install with `pip install -r requirements.txt`.")
35
- print(f"Import error: {e}")
36
- return None
37
-
38
- class EpisodeRewardPrinter(BaseCallback):
39
- """Print episode reward when an episode ends (via Monitor)."""
40
-
41
- def __init__(self) -> None:
42
- super().__init__()
43
- self.episode_count = 0
44
-
45
- def _on_step(self) -> bool:
46
- infos = self.locals.get("infos", [])
47
- for info in infos:
48
- ep = info.get("episode") if isinstance(info, dict) else None
49
- if isinstance(ep, dict) and "r" in ep:
50
- self.episode_count += 1
51
- print(f"episode={self.episode_count} reward={ep['r']:.2f} length={int(ep.get('l', 0))}")
52
- return True
53
-
54
- # Wrap with Monitor so SB3 can compute episode stats and expose them in `info["episode"]`.
55
- def make_env() -> RefactorEnv:
56
- return Monitor(_env)
57
-
58
- vec_env = DummyVecEnv([make_env])
59
-
60
- model = PPO(
61
- policy="MlpPolicy",
62
- env=vec_env,
63
- verbose=0,
64
- seed=_config.seed,
65
- n_steps=64,
66
- batch_size=64,
67
- )
68
-
69
- print(f"Training PPO for {int(_config.total_steps)} timesteps...")
70
- model.learn(total_timesteps=int(_config.total_steps), callback=EpisodeRewardPrinter())
71
-
72
- model.save(_config.model_path)
73
- print(f"Saved model to {_config.model_path!r}")
74
- return None
75
-
ACRE_FINAL/acre/utils/__init__.py DELETED
@@ -1,6 +0,0 @@
1
- """Shared utility helpers for ACRE."""
2
-
3
- from .metrics import Metric, MetricLogger
4
-
5
- __all__ = ["Metric", "MetricLogger"]
6
-
ACRE_FINAL/acre/utils/metrics.py DELETED
@@ -1,33 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from dataclasses import dataclass, field
4
- from typing import Dict, Iterable, List, Tuple
5
-
6
-
7
- @dataclass(frozen=True)
8
- class Metric:
9
- """Single scalar metric value (placeholder)."""
10
-
11
- name: str
12
- value: float
13
-
14
-
15
- @dataclass
16
- class MetricLogger:
17
- """Tiny metric logger stub."""
18
-
19
- _history: Dict[str, List[float]] = field(default_factory=dict)
20
-
21
- def log(self, metric: Metric) -> None:
22
- self._history.setdefault(metric.name, []).append(metric.value)
23
-
24
- def latest(self) -> Dict[str, float]:
25
- return {k: v[-1] for k, v in self._history.items() if v}
26
-
27
- def as_series(self) -> Dict[str, Tuple[float, ...]]:
28
- return {k: tuple(v) for k, v in self._history.items()}
29
-
30
- def extend(self, metrics: Iterable[Metric]) -> None:
31
- for m in metrics:
32
- self.log(m)
33
-
ACRE_FINAL/inference.py DELETED
@@ -1,278 +0,0 @@
1
- """
2
- ACRE inference script for OpenEnv submission evaluation.
3
-
4
- Required environment variables:
5
- API_BASE_URL: LLM API endpoint (default allowed)
6
- MODEL_NAME: model identifier (default allowed)
7
- HF_TOKEN: API token for the OpenAI-compatible endpoint
8
- ENV_URL: running ACRE server base URL
9
-
10
- Optional:
11
- LOCAL_IMAGE_NAME: present for evaluator compatibility when using a local
12
- Docker image launcher.
13
-
14
- Stdout format uses strict START / STEP / END event markers.
15
- """
16
- from __future__ import annotations
17
-
18
- import json
19
- import os
20
- import re
21
- import sys
22
- import time
23
- from typing import Dict, List, Tuple
24
-
25
- import requests
26
- from openai import OpenAI
27
-
28
- API_BASE_URL: str = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
29
- MODEL_NAME: str = os.getenv("MODEL_NAME", "gpt-4o-mini")
30
- HF_TOKEN: str | None = os.getenv("HF_TOKEN")
31
- ENV_URL: str | None = os.getenv("ENV_URL")
32
- LOCAL_IMAGE_NAME: str | None = os.getenv("LOCAL_IMAGE_NAME")
33
-
34
- TASKS: List[str] = ["rename_variables", "remove_dead_code", "full_refactor"]
35
-
36
- ACTION_MEANINGS: Dict[int, str] = {
37
- 0: "rename_variable",
38
- 1: "remove_dead_code",
39
- 2: "simplify_loop",
40
- 3: "optimize_condition",
41
- 4: "inline_function",
42
- }
43
-
44
- SYSTEM_PROMPT = """\
45
- You are an RL agent that refactors Python code. Choose one action per step.
46
-
47
- Actions:
48
- 0 rename_variable - rename generic names (x, tmp, i) to descriptive ones
49
- 1 remove_dead_code - remove unreachable stmts, if False blocks, unused vars
50
- 2 simplify_loop - convert append-loops to list comprehensions
51
- 3 optimize_condition- simplify 'not not x', 'if True/False', 'x==True'
52
- 4 inline_function - inline simple single-return module-level functions
53
-
54
- Respond ONLY with valid JSON (no markdown):
55
- {"action": <0-4>, "reason": "<one sentence>"}"""
56
-
57
-
58
- def _env_url() -> str:
59
- if ENV_URL:
60
- return ENV_URL.rstrip("/")
61
- raise RuntimeError("ENV_URL must be set before running inference.py")
62
-
63
-
64
- def _post(path: str, payload: dict | None = None) -> dict:
65
- response = requests.post(f"{_env_url()}{path}", json=payload or {}, timeout=30)
66
- response.raise_for_status()
67
- return response.json()
68
-
69
-
70
- def _get(path: str) -> dict:
71
- response = requests.get(f"{_env_url()}{path}", timeout=30)
72
- response.raise_for_status()
73
- return response.json()
74
-
75
-
76
- def reset_env(task_id: str) -> dict:
77
- return _post("/reset", {"task_id": task_id})
78
-
79
-
80
- def step_env(action: int) -> dict:
81
- return _post("/step", {"action": action})
82
-
83
-
84
- def get_state() -> dict:
85
- return _get("/state")
86
-
87
-
88
- def grade(task_id: str, code: str) -> float:
89
- response = requests.post(
90
- f"{_env_url()}/tasks/{task_id}/grade",
91
- json={"code": code},
92
- timeout=30,
93
- )
94
- response.raise_for_status()
95
- return float(response.json().get("score", 0.0))
96
-
97
-
98
- def choose_action(client: OpenAI, state: dict, task_id: str) -> Tuple[int, str]:
99
- def heuristic_action() -> Tuple[int, str]:
100
- code = str(state.get("current_code", ""))
101
- step_i = int(state.get("episode_steps", 0))
102
-
103
- has_generic = re.search(r"\b(x|tmp|i)\b", code) is not None
104
- has_if_false = re.search(r"\bif\s+False\b", code) is not None
105
- has_if_true = re.search(r"\bif\s+True\b", code) is not None
106
- has_append_loop = ".append(" in code and "for " in code
107
- has_double_not = "not not" in code
108
- has_add_call = "add(" in code
109
-
110
- if task_id == "rename_variables":
111
- if has_generic:
112
- return 0, "heuristic: remove generic names first"
113
- if has_if_false or "unused" in code:
114
- return 1, "heuristic: remove dead code"
115
- if has_append_loop:
116
- return 2, "heuristic: simplify loop"
117
- if has_if_true or has_double_not:
118
- return 3, "heuristic: optimize conditions"
119
- return 4, "heuristic: inline simple function"
120
-
121
- if task_id == "remove_dead_code":
122
- if has_if_false or "unused" in code:
123
- return 1, "heuristic: remove dead code patterns"
124
- if has_append_loop:
125
- return 2, "heuristic: convert append-loop"
126
- if has_if_true or has_double_not:
127
- return 3, "heuristic: simplify conditions"
128
- if has_generic:
129
- return 0, "heuristic: clean generic names"
130
- return 4, "heuristic: inline helper"
131
-
132
- if has_generic:
133
- return 0, "heuristic: rename generic variables"
134
- if has_append_loop:
135
- return 2, "heuristic: simplify loop into listcomp"
136
- if has_if_false or has_if_true or has_double_not:
137
- return 3, "heuristic: optimize boolean branches"
138
- if has_add_call:
139
- return 4, "heuristic: inline add() call"
140
- if step_i >= 2:
141
- return 1, "heuristic: remove remaining dead code"
142
- return 3, "heuristic: condition optimization as safe default"
143
-
144
- if not HF_TOKEN:
145
- return heuristic_action()
146
-
147
- messages = [
148
- {"role": "system", "content": SYSTEM_PROMPT},
149
- {
150
- "role": "user",
151
- "content": (
152
- f"Task: {task_id}\n"
153
- f"Steps remaining: {state.get('max_steps', 5) - state.get('episode_steps', 0)}\n"
154
- f"Complexity: {state.get('complexity', 0)}\n\n"
155
- f"Current code:\n```python\n{state.get('current_code', '')}\n```\n\n"
156
- "Choose the best action."
157
- ),
158
- },
159
- ]
160
- try:
161
- response = client.chat.completions.create(
162
- model=MODEL_NAME,
163
- messages=messages,
164
- temperature=0.0,
165
- max_tokens=120,
166
- )
167
- raw = (response.choices[0].message.content or "").strip()
168
- json_blob = raw
169
-
170
- if "{" not in json_blob or "}" not in json_blob:
171
- return heuristic_action()
172
-
173
- match = re.search(r"\{.*\}", json_blob, flags=re.DOTALL)
174
- if match:
175
- json_blob = match.group(0)
176
-
177
- parsed = json.loads(json_blob)
178
- action = int(parsed.get("action", -1))
179
- reason = str(parsed.get("reason", ""))
180
- if 0 <= action <= 4:
181
- return action, reason or "llm-selected action"
182
- return heuristic_action()
183
- except Exception:
184
- return heuristic_action()
185
-
186
-
187
- def run_episode(client: OpenAI, task_id: str, episode_num: int) -> float:
188
- reset_env(task_id)
189
- state = get_state()
190
-
191
- print(
192
- json.dumps(
193
- {
194
- "event": "START",
195
- "episode": episode_num,
196
- "task_id": task_id,
197
- "initial_complexity": state.get("complexity", 0),
198
- "initial_code_length": len(state.get("current_code", "")),
199
- "timestamp": time.time(),
200
- }
201
- ),
202
- flush=True,
203
- )
204
-
205
- cumulative_reward = 0.0
206
-
207
- for step_num in range(1, 6):
208
- action, reason = choose_action(client, state, task_id)
209
- result = step_env(action)
210
- state = get_state()
211
-
212
- reward_payload = result.get("reward", {})
213
- raw_reward = float(reward_payload.get("raw", 0.0))
214
- norm_reward = float(reward_payload.get("normalized", (raw_reward + 32) / 52))
215
- cumulative_reward += raw_reward
216
-
217
- print(
218
- json.dumps(
219
- {
220
- "event": "STEP",
221
- "episode": episode_num,
222
- "step": step_num,
223
- "action": action,
224
- "action_name": ACTION_MEANINGS.get(action, "unknown"),
225
- "reason": reason,
226
- "reward": round(raw_reward, 4),
227
- "normalized_reward": round(norm_reward, 4),
228
- "cumulative_reward": round(cumulative_reward, 4),
229
- "changed": result.get("info", {}).get("changed", False),
230
- "reward_components": reward_payload.get("components", {}),
231
- "done": result.get("done", False),
232
- }
233
- ),
234
- flush=True,
235
- )
236
-
237
- if result.get("done") or result.get("terminated") or result.get("truncated"):
238
- break
239
-
240
- final_state = get_state()
241
- task_score = grade(task_id, final_state.get("current_code", ""))
242
-
243
- print(
244
- json.dumps(
245
- {
246
- "event": "END",
247
- "episode": episode_num,
248
- "task_id": task_id,
249
- "cumulative_reward": round(cumulative_reward, 4),
250
- "normalized_cumulative": round((cumulative_reward + 32) / 52, 4),
251
- "task_score": round(task_score, 4),
252
- "final_complexity": final_state.get("complexity", 0),
253
- "timestamp": time.time(),
254
- }
255
- ),
256
- flush=True,
257
- )
258
-
259
- return task_score
260
-
261
-
262
- def main() -> None:
263
- if not ENV_URL:
264
- raise SystemExit("ENV_URL is required. Example: ENV_URL=http://localhost:7860")
265
-
266
- client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN or "dummy")
267
-
268
- scores: List[float] = []
269
- for i, task_id in enumerate(TASKS, start=1):
270
- score = run_episode(client, task_id, i)
271
- scores.append(score)
272
-
273
- avg_score = sum(scores) / len(scores) if scores else 0.0
274
- sys.exit(0 if avg_score >= 0.5 else 1)
275
-
276
-
277
- if __name__ == "__main__":
278
- main()
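The LLM branch of `choose_action` in the deleted `inference.py` extracts a JSON object from the model reply and falls back to the heuristic when anything looks wrong. A minimal sketch of that parsing step in isolation, assuming a hypothetical helper named `parse_action` that returns `None` where the original would call `heuristic_action()`:

```python
import json
import re
from typing import Optional


def parse_action(raw: str, n_actions: int = 5) -> Optional[int]:
    """Extract a valid action index from an LLM reply, or None if unusable."""
    # Grab the outermost {...} span so surrounding prose or code fences are ignored.
    match = re.search(r"\{.*\}", raw, flags=re.DOTALL)
    if match is None:
        return None
    try:
        parsed = json.loads(match.group(0))
        action = int(parsed.get("action", -1))
    except (ValueError, TypeError, AttributeError):
        # Malformed JSON, a non-numeric action, or a non-dict payload.
        return None
    # Only indices inside the discrete action space are accepted.
    return action if 0 <= action < n_actions else None
```

Returning `None` rather than raising keeps the caller's fallback decision in one place; the original code reaches the same effect with a broad `except Exception`.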
ACRE_FINAL/models.py DELETED
@@ -1,156 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from typing import Any, Dict, List, Optional, Sequence
4
-
5
- from pydantic import BaseModel, Field
6
-
7
-
8
- class ObservationModel(BaseModel):
9
- code_length: float
10
- complexity_score: float
11
- runtime_s: float
12
- error_flag: bool
13
-
14
- @classmethod
15
- def from_vector(cls, values: Sequence[float]) -> "ObservationModel":
16
- vector = list(values)
17
- if len(vector) != 4:
18
- raise ValueError(f"observation vector must have length 4, got {len(vector)}")
19
- return cls(
20
- code_length=float(vector[0]),
21
- complexity_score=float(vector[1]),
22
- runtime_s=float(vector[2]),
23
- error_flag=bool(vector[3]),
24
- )
25
-
26
- def to_vector(self) -> List[float]:
27
- return [
28
- float(self.code_length),
29
- float(self.complexity_score),
30
- float(self.runtime_s),
31
- float(int(self.error_flag)),
32
- ]
33
-
34
-
35
- class ActionModel(BaseModel):
36
- action: int = Field(ge=0, le=4)
37
- action_name: Optional[str] = None
38
-
39
-
40
- class RewardModel(BaseModel):
41
- raw: float
42
- normalized: float = Field(ge=0.0, le=1.0)
43
- components: Dict[str, float]
44
-
45
-
46
- class HealthResponse(BaseModel):
47
- status: str
48
- env: str
49
- version: str
50
-
51
-
52
- class CompatibilityHealthResponse(BaseModel):
53
- status: str
54
- service: str
55
-
56
-
57
- class ResetRequest(BaseModel):
58
- task_id: Optional[str] = None
59
- seed: Optional[int] = None
60
- code: Optional[str] = None
61
-
62
-
63
- class StepRequest(BaseModel):
64
- action: int = Field(ge=0, le=4)
65
-
66
-
67
- class GradeRequest(BaseModel):
68
- code: str
69
-
70
-
71
- class TaskInfo(BaseModel):
72
- id: str
73
- name: str
74
- description: str
75
- difficulty: str
76
- initial_code: str
77
-
78
-
79
- class TasksResponse(BaseModel):
80
- tasks: List[TaskInfo]
81
-
82
-
83
- class GradeResponse(BaseModel):
84
- task_id: str
85
- score: float
86
- passed: bool
87
-
88
-
89
- class StateResponse(BaseModel):
90
- current_code: str
91
- episode_steps: int
92
- max_steps: int
93
- complexity: float
94
- last_runtime: float
95
- last_error: bool
96
- sample_id: Optional[str]
97
- language: Optional[str]
98
- task_id: Optional[str]
99
- observation: ObservationModel
100
- observation_vector: List[float]
101
- action_meanings: Dict[int, str]
102
-
103
-
104
- class ResetResponse(BaseModel):
105
- observation: ObservationModel
106
- observation_vector: List[float]
107
- info: Dict[str, Any]
108
- task_id: Optional[str]
109
- state: StateResponse
110
-
111
-
112
- class StepResponse(BaseModel):
113
- action: ActionModel
114
- observation: ObservationModel
115
- observation_vector: List[float]
116
- reward: RewardModel
117
- done: bool
118
- terminated: bool
119
- truncated: bool
120
- info: Dict[str, Any]
121
- state: StateResponse
122
-
123
-
124
- class OptimizeRequest(BaseModel):
125
- code: str
126
- task_id: Optional[str] = None
127
- max_steps: int = Field(default=5, ge=1, le=5)
128
- use_rl: bool = True
129
- use_llm: bool = False
130
- fallback_to_llm: bool = True
131
- rl_model_path: Optional[str] = None
132
- api_base_url: Optional[str] = None
133
- model_name: Optional[str] = None
134
- api_token: Optional[str] = None
135
-
136
-
137
- class OptimizationStep(BaseModel):
138
- step: int
139
- action: int
140
- action_name: str
141
- reason: str
142
- source: str
143
- reward: float
144
- normalized_reward: float
145
- changed: bool
146
- complexity: float
147
-
148
-
149
- class OptimizeResponse(BaseModel):
150
- original_code: str
151
- optimized_code: str
152
- diff: str
153
- steps: List[OptimizationStep]
154
- cumulative_reward: float
155
- task_id: Optional[str]
156
- task_score: Optional[float]
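The `from_vector` / `to_vector` pair on `ObservationModel` converts between the environment's 4-element observation vector and named fields. The sketch below shows the same round trip using a standard-library dataclass instead of Pydantic so that it runs without extra dependencies; the class name `Observation` is hypothetical, and the swap drops Pydantic's validation.

```python
from dataclasses import dataclass
from typing import List, Sequence


@dataclass(frozen=True)
class Observation:
    """Dataclass stand-in for the Pydantic ObservationModel in the diff."""

    code_length: float
    complexity_score: float
    runtime_s: float
    error_flag: bool

    @classmethod
    def from_vector(cls, values: Sequence[float]) -> "Observation":
        vector = list(values)
        if len(vector) != 4:
            raise ValueError(f"observation vector must have length 4, got {len(vector)}")
        # The last slot is a 0/1 flag in the vector and a bool on the model.
        return cls(float(vector[0]), float(vector[1]), float(vector[2]), bool(vector[3]))

    def to_vector(self) -> List[float]:
        # Encode the flag back as 0.0 or 1.0 to match the Box observation space.
        return [self.code_length, self.complexity_score, self.runtime_s, float(self.error_flag)]


obs = Observation.from_vector([120.0, 7.0, 0.002, 1.0])
round_trip = obs.to_vector()
```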
ACRE_FINAL/openenv.yaml DELETED
@@ -1,85 +0,0 @@
1
- name: ACRE
2
- version: "1.0.0"
3
- description: >
4
- Autonomous Code Refactoring Environment - an RL environment where an
5
- agent improves Python code quality using AST-level transformations.
6
- author: "Nikhil Pratap Singh, Pranav Mangal, Ananya Gupta"
7
- entrypoint: "openenv_interface:OpenEnvRefactorEnv"
8
- tags:
9
- - openenv
10
-
11
- tasks:
12
- - id: rename_variables
13
- name: "Rename Variables (Easy)"
14
- description: "Rename generic variable names (x, tmp) to descriptive ones"
15
- difficulty: easy
16
- reward_range: [0.0, 1.0]
17
- max_steps: 5
18
-
19
- - id: remove_dead_code
20
- name: "Remove Dead Code (Medium)"
21
- description: "Remove unreachable statements, if-False blocks, and unused assignments"
22
- difficulty: medium
23
- reward_range: [0.0, 1.0]
24
- max_steps: 5
25
-
26
- - id: full_refactor
27
- name: "Full Refactor (Hard)"
28
- description: "Apply all transformations - rename, dead code removal, loop simplification, condition optimization, and function inlining"
29
- difficulty: hard
30
- reward_range: [0.0, 1.0]
31
- max_steps: 5
32
-
33
- observation_space:
34
- type: Box
35
- shape: [4]
36
- dtype: float32
37
- low: [0.0, 0.0, 0.0, 0.0]
38
- high: [inf, inf, inf, 1.0]
39
- fields:
40
- - code_length
41
- - complexity_score
42
- - runtime_s
43
- - error_flag
44
-
45
- action_space:
46
- type: Discrete
47
- n: 5
48
- actions:
49
- 0: rename_variable
50
- 1: remove_dead_code
51
- 2: simplify_loop
52
- 3: optimize_condition
53
- 4: inline_function
54
-
55
- api:
56
- health: "GET /"
57
- reset: "POST /reset"
58
- step: "POST /step"
59
- state: "GET /state"
60
- tasks: "GET /tasks"
61
- grade: "POST /tasks/{task_id}/grade"
62
-
63
- reward:
64
- raw_range: [-32, 20]
65
- normalized_range: [0.0, 1.0]
66
- formula: "(raw + 32) / 52"
67
- components:
68
- success: { max: 10, min: -10 }
69
- complexity: { max: 5, min: -5 }
70
- performance: { max: 5, min: -2 }
71
- error: { max: 0, min: -15 }
72
- no_change: { max: 0, min: -2 }
73
-
74
- validation:
75
- python_api:
76
- reset: "ObservationModel"
77
- step: "(ObservationModel, RewardModel, done, info)"
78
- state: "StateResponse"
79
- http_api:
80
- health: "GET /"
81
- reset: "POST /reset"
82
- step: "POST /step"
83
- state: "GET /state"
84
- tasks: "GET /tasks"
85
- grade: "POST /tasks/{task_id}/grade"
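The `reward` section of the deleted `openenv.yaml` maps a raw reward in `[-32, 20]` to `[0.0, 1.0]` with `(raw + 32) / 52`. A small worked sketch of that formula follows; the function name `normalize_reward` is hypothetical, and the clamping step is an assumption that is not stated in the YAML (it is added here because `RewardModel.normalized` is constrained to `[0.0, 1.0]`).

```python
# Bounds taken from `raw_range` in the deleted openenv.yaml.
RAW_MIN = -32.0
RAW_MAX = 20.0


def normalize_reward(raw: float) -> float:
    """Map a raw step reward to [0.0, 1.0] via (raw + 32) / 52."""
    # Assumption: clamp out-of-range values so the result stays within [0.0, 1.0].
    clipped = min(RAW_MAX, max(RAW_MIN, raw))
    return (clipped - RAW_MIN) / (RAW_MAX - RAW_MIN)


worst = normalize_reward(-32.0)   # lower bound of raw_range
best = normalize_reward(20.0)     # upper bound of raw_range
neutral = normalize_reward(0.0)   # 32 / 52, roughly 0.615
```

Note that a raw reward of zero lands at about 0.615 rather than 0.5, because the raw range is not symmetric around zero.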
ACRE_FINAL/openenv_interface.py DELETED
@@ -1,116 +0,0 @@
- from __future__ import annotations
-
- from typing import Any, Dict, Optional, Tuple
-
- try:
-     from openenv.env import Env as OpenEnvBase
- except Exception: # pragma: no cover
-     class OpenEnvBase:
-         def __init__(self, *args: Any, **kwargs: Any) -> None:
-             return None
-
- from acre.datasets.code_samples import CodeSample, CodeSampleDataset
- from acre.env.refactor_env import RefactorEnv
- from acre.tasks.task_registry import TaskRegistry
- from models import ActionModel, ObservationModel, RewardModel, StateResponse
-
-
- class OpenEnvRefactorEnv(OpenEnvBase):
-     """
-     Canonical OpenEnv interface for ACRE.
-
-     This wrapper keeps the strict hackathon contract:
-     - reset() -> ObservationModel
-     - step(action) -> (ObservationModel, RewardModel, done, info)
-     - state() -> StateResponse
-     """
-
-     def __init__(
-         self,
-         *,
-         env: Optional[RefactorEnv] = None,
-         registry: Optional[TaskRegistry] = None,
-     ) -> None:
-         super().__init__(
-             name="ACRE",
-             state_space="ObservationModel",
-             action_space="ActionModel",
-             episode_max_length=RefactorEnv.MAX_STEPS,
-         )
-         self._env = env or RefactorEnv()
-         self._registry = registry or TaskRegistry()
-         self._task_id: Optional[str] = None
-         self._last_reset_info: Dict[str, Any] = {}
-
-     @property
-     def action_meanings(self) -> Dict[int, str]:
-         return self._env.ACTION_MEANINGS
-
-     @property
-     def last_reset_info(self) -> Dict[str, Any]:
-         return dict(self._last_reset_info)
-
-     def _load_episode_source(self, *, task_id: Optional[str], code: Optional[str]) -> None:
-         initial_code = code
-         if initial_code is None and task_id:
-             task = self._registry.get_task(task_id)
-             if task is None:
-                 raise ValueError(f"Task '{task_id}' not found")
-             initial_code = task.initial_code
-
-         if initial_code is None:
-             return None
-
-         self._env.dataset = CodeSampleDataset(
-             [
-                 CodeSample(
-                     id=task_id or "custom",
-                     language="python",
-                     code=initial_code,
-                 )
-             ]
-         )
-         return None
-
-     def reset(
-         self,
-         *,
-         seed: Optional[int] = None,
-         task_id: Optional[str] = None,
-         code: Optional[str] = None,
-     ) -> ObservationModel:
-         self._task_id = task_id
-         self._load_episode_source(task_id=task_id, code=code)
-         observation, info = self._env.reset(seed=seed)
-         self._last_reset_info = dict(info)
-         return ObservationModel.from_vector(observation.tolist())
-
-     def step(self, action: int | ActionModel) -> Tuple[ObservationModel, RewardModel, bool, Dict[str, Any]]:
-         action_value = action.action if isinstance(action, ActionModel) else int(action)
-         observation, raw_reward, terminated, truncated, info = self._env.step(action_value)
-         reward = RewardModel(
-             raw=float(raw_reward),
-             normalized=float(info.get("normalized_reward", 0.0)),
-             components=dict(info.get("reward_components", {})),
-         )
-         done = bool(terminated or truncated)
-         return ObservationModel.from_vector(observation.tolist()), reward, done, dict(info)
-
-     def state(self) -> StateResponse:
-         raw_state = self._env.state()
-         observation_vector = list(raw_state.get("observation", [0.0, 0.0, 0.0, 0.0]))
-         observation = ObservationModel.from_vector(observation_vector)
-         return StateResponse(
-             current_code=str(raw_state.get("current_code", "")),
-             episode_steps=int(raw_state.get("episode_steps", 0)),
-             max_steps=int(raw_state.get("max_steps", RefactorEnv.MAX_STEPS)),
-             complexity=float(raw_state.get("complexity", 0.0)),
-             last_runtime=float(raw_state.get("last_runtime", 0.0)),
-             last_error=bool(raw_state.get("last_error", False)),
-             sample_id=raw_state.get("sample_id"),
-             language=raw_state.get("language"),
-             task_id=self._task_id,
-             observation=observation,
-             observation_vector=observation.to_vector(),
-             action_meanings=dict(raw_state.get("action_meanings", {})),
-         )
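The wrapper above collapses the five-value Gymnasium-style step result into the four-value contract by merging `terminated` and `truncated` into a single `done` flag. A minimal, self-contained sketch of that idea follows. `StubEnv`, `Reward`, and `FourTupleAdapter` are illustrative stand-ins invented here, not the repository's classes, and the reward numbers are arbitrary.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple


@dataclass
class Reward:
    """Stand-in for a typed reward payload (hypothetical)."""
    raw: float
    normalized: float
    components: Dict[str, float] = field(default_factory=dict)


class StubEnv:
    """Hypothetical Gymnasium-style env that truncates after max_steps."""

    def __init__(self, max_steps: int = 2) -> None:
        self.max_steps = max_steps
        self.steps = 0

    def step(self, action: int) -> Tuple[List[float], float, bool, bool, Dict[str, Any]]:
        self.steps += 1
        truncated = self.steps >= self.max_steps
        info = {"normalized_reward": 0.5, "reward_components": {"size": 1.0}}
        return [0.0, 0.0, 0.0, 0.0], 1.0, False, truncated, info


class FourTupleAdapter:
    """Expose (obs, reward, done, info) on top of a five-tuple env."""

    def __init__(self, env: StubEnv) -> None:
        self._env = env

    def step(self, action: int) -> Tuple[List[float], Reward, bool, Dict[str, Any]]:
        obs, raw, terminated, truncated, info = self._env.step(action)
        reward = Reward(
            raw=float(raw),
            normalized=float(info.get("normalized_reward", 0.0)),
            components=dict(info.get("reward_components", {})),
        )
        # Either end condition finishes the episode for the caller.
        return obs, reward, bool(terminated or truncated), dict(info)
```

One consequence of merging the two flags is that a caller can no longer distinguish a step-limit cutoff from a true terminal state unless that detail is also carried in `info`.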
ACRE_FINAL/requirements.txt DELETED
@@ -1,11 +0,0 @@
- fastapi>=0.109.0
- uvicorn[standard]>=0.27.0
- numpy>=1.26
- gymnasium
- stable-baselines3
- radon>=6.0.1
- openai>=1.0.0
- openenv>=0.1.13
- requests>=2.31.0
- pydantic>=2.0.0
- typing_extensions>=4.0.0
ACRE_FINAL/server.py DELETED
@@ -1,667 +0,0 @@
- """
- ACRE OpenEnv HTTP server.
-
- Endpoints (all required by OpenEnv spec):
-   GET / — health check (must return HTTP 200)
-   POST /reset — reset environment, returns observation + info
-   POST /step — take one step, returns obs/reward/done/info
-   GET /state — full current state snapshot
-   GET /tasks — list all tasks with initial code
-   POST /tasks/{task_id}/grade — grade code for a specific task
- """
- from __future__ import annotations
-
- import difflib
- import os
- import re
- import json
- from typing import Optional
-
- import uvicorn
- import numpy as np
- from fastapi import FastAPI, HTTPException
- from fastapi.middleware.cors import CORSMiddleware
- from fastapi.responses import HTMLResponse
- from openai import OpenAI
-
- try:
-     from stable_baselines3 import PPO
- except Exception:
-     PPO = None # type: ignore[assignment]
-
- from acre.tasks.task_registry import TaskRegistry
- from models import (
-     ActionModel,
-     CompatibilityHealthResponse,
-     GradeRequest,
-     GradeResponse,
-     HealthResponse,
-     OptimizationStep,
-     OptimizeRequest,
-     OptimizeResponse,
-     ResetRequest,
-     ResetResponse,
-     StateResponse,
-     StepRequest,
-     StepResponse,
-     TaskInfo,
-     TasksResponse,
- )
- from openenv_interface import OpenEnvRefactorEnv
-
- DEFAULT_API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
- DEFAULT_MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")
- DEFAULT_RL_MODEL_PATH = os.getenv("RL_MODEL_PATH", "acre_agent.zip")
-
- # ---------------------------------------------------------------------------
- # App setup
- # ---------------------------------------------------------------------------
-
- app = FastAPI(
-     title="ACRE — Autonomous Code Refactoring Environment",
-     description="OpenEnv-compatible RL environment for Python code refactoring.",
-     version="1.0.0",
- )
-
- app.add_middleware(
-     CORSMiddleware,
-     allow_origins=["*"],
-     allow_methods=["*"],
-     allow_headers=["*"],
- )
-
- # Global singletons
- registry = TaskRegistry()
- _env: Optional[OpenEnvRefactorEnv] = None
- _rl_model_cache: dict[str, object] = {}
-
-
- def get_env() -> OpenEnvRefactorEnv:
-     global _env
-     if _env is None:
-         _env = OpenEnvRefactorEnv(registry=registry)
-     return _env
-
-
- def _state_response() -> StateResponse:
-     return get_env().state()
-
-
- def _choose_action_heuristic(code: str, task_id: Optional[str]) -> int:
-     has_generic = re.search(r"\b(x|tmp|i)\b", code) is not None
-     has_if_false = re.search(r"\bif\s+False\b", code) is not None
-     has_if_true = re.search(r"\bif\s+True\b", code) is not None
-     has_append_loop = ".append(" in code and "for " in code
-     has_double_not = "not not" in code
-     has_add_call = "add(" in code
-
-     if task_id == "rename_variables":
-         if has_generic:
-             return 0
-         if has_if_false or "unused" in code:
-             return 1
-         if has_append_loop:
-             return 2
-         if has_if_true or has_double_not:
-             return 3
-         return 4
-
-     if task_id == "remove_dead_code":
-         if has_if_false or "unused" in code:
-             return 1
-         if has_append_loop:
-             return 2
-         if has_if_true or has_double_not:
-             return 3
-         if has_generic:
-             return 0
-         return 4
-
-     if has_generic:
-         return 0
-     if has_append_loop:
-         return 2
-     if has_if_false or has_if_true or has_double_not:
-         return 3
-     if has_add_call:
-         return 4
-     return 1
-
-
- def _choose_action_llm(
-     *,
-     code: str,
-     task_id: Optional[str],
-     step_index: int,
-     max_steps: int,
-     api_base_url: str,
-     model_name: str,
-     api_token: str,
- ) -> tuple[int, str, str]:
-     if not api_token.strip():
-         return _choose_action_heuristic(code, task_id), "empty token -> heuristic", "heuristic"
-
-     client = OpenAI(base_url=api_base_url, api_key=api_token)
-     messages = [
-         {
-             "role": "system",
-             "content": (
-                 "You are a code-refactoring action selector. Return ONLY compact JSON: "
-                 '{"action": <0-4>, "reason": "..."}.\n'
-                 "Actions: 0=rename_variable,1=remove_dead_code,2=simplify_loop,3=optimize_condition,4=inline_function"
-             ),
-         },
-         {
-             "role": "user",
-             "content": (
-                 f"task_id={task_id or 'auto'}\n"
-                 f"step={step_index}/{max_steps}\n"
-                 "Current code:\n"
-                 f"```python\n{code}\n```"
-             ),
-         },
-     ]
-     try:
-         resp = client.chat.completions.create(
-             model=model_name,
-             messages=messages,
-             temperature=0.0,
-             max_tokens=120,
-         )
-         raw = (resp.choices[0].message.content or "").strip()
-         m = re.search(r"\{.*\}", raw, flags=re.DOTALL)
-         blob = m.group(0) if m else raw
-         parsed = json.loads(blob)
-         action = int(parsed.get("action", -1))
-         reason = str(parsed.get("reason", "llm-selected action"))
-         if 0 <= action <= 4:
-             return action, reason, "llm"
-     except Exception as exc:
-         return _choose_action_heuristic(code, task_id), f"llm error -> heuristic: {exc}", "heuristic"
-
-     return _choose_action_heuristic(code, task_id), "invalid llm output -> heuristic", "heuristic"
-
-
- def _choose_action_rl(observation: list[float], model_path: str) -> tuple[Optional[int], str, str]:
-     if PPO is None:
-         return None, "stable-baselines3 unavailable", "rl"
-     if not os.path.exists(model_path):
-         return None, f"rl model not found: {model_path}", "rl"
-
-     try:
-         model = _rl_model_cache.get(model_path)
-         if model is None:
-             model = PPO.load(model_path)
-             _rl_model_cache[model_path] = model
-
-         obs = np.asarray(observation, dtype=np.float32)
-         action, _ = model.predict(obs, deterministic=True)
-         action_i = int(action)
-         if 0 <= action_i <= 4:
-             return action_i, "rl policy action", "rl"
-         return None, f"invalid rl action: {action_i}", "rl"
-     except Exception as exc:
-         return None, f"rl failure: {exc}", "rl"
-
-
- def _demo_html() -> str:
-     return """<!doctype html>
- <html lang=\"en\">
- <head>
-   <meta charset=\"utf-8\" />
-   <meta name=\"viewport\" content=\"width=device-width, initial-scale=1\" />
-   <title>ACRE Refactor Demo</title>
-   <style>
-     @import url('https://fonts.googleapis.com/css2?family=Space+Grotesk:wght@400;600;700&display=swap');
-     :root {
-       --bg0: #0b1f2a;
-       --bg1: #14344a;
-       --ink: #eaf7ff;
-       --muted: #a7c8db;
-       --brand: #1ec28b;
-       --warn: #ffcb47;
-       --panel: rgba(8, 24, 36, 0.72);
-       --stroke: rgba(140, 197, 225, 0.35);
-     }
-     * { box-sizing: border-box; }
-     body {
-       margin: 0;
-       color: var(--ink);
-       font-family: 'Space Grotesk', sans-serif;
-       background:
-         radial-gradient(circle at 12% 18%, rgba(30, 194, 139, 0.28), transparent 35%),
-         radial-gradient(circle at 88% 8%, rgba(255, 203, 71, 0.22), transparent 30%),
-         linear-gradient(150deg, var(--bg0), var(--bg1));
-       min-height: 100vh;
-     }
-     .wrap {
-       max-width: 1200px;
-       margin: 0 auto;
-       padding: 28px 20px 40px;
-     }
-     h1 {
-       margin: 0 0 6px;
-       font-size: clamp(1.6rem, 2vw + 1rem, 2.6rem);
-       letter-spacing: 0.2px;
-     }
-     .sub { margin: 0 0 20px; color: var(--muted); }
-     .grid {
-       display: grid;
-       grid-template-columns: 1fr;
-       gap: 16px;
-     }
-     .panel {
-       border: 1px solid var(--stroke);
-       border-radius: 14px;
-       background: var(--panel);
-       backdrop-filter: blur(4px);
-       padding: 14px;
-     }
-     .controls {
-       display: grid;
-       grid-template-columns: 1fr 1fr;
-       gap: 8px;
-       margin-bottom: 10px;
-     }
-     textarea, pre {
-       width: 100%;
-       min-height: 260px;
-       border: 1px solid var(--stroke);
-       border-radius: 10px;
-       padding: 12px;
-       background: rgba(1, 13, 24, 0.82);
-       color: #dcf4ff;
-       font-family: Consolas, 'Courier New', monospace;
-       font-size: 13px;
-       line-height: 1.4;
-       overflow: auto;
-       white-space: pre;
-     }
-     button, select {
-       border: 1px solid var(--stroke);
-       border-radius: 10px;
-       padding: 10px 12px;
-       background: rgba(11, 36, 52, 0.9);
-       color: var(--ink);
-       font-weight: 600;
-     }
-     button.primary {
-       background: linear-gradient(120deg, #19a7ff, #1ec28b);
-       color: #032235;
-       border: none;
-     }
-     .cols {
-       display: grid;
-       grid-template-columns: 1fr;
-       gap: 14px;
-     }
-     .meta {
-       color: var(--muted);
-       font-size: 0.92rem;
-       margin-top: 8px;
-     }
-     .badge {
-       color: #082b22;
-       background: var(--brand);
-       border-radius: 999px;
-       padding: 2px 9px;
-       font-size: 12px;
-       font-weight: 700;
-     }
-     .warn {
-       color: #2a1c00;
-       background: var(--warn);
-     }
-     @media (min-width: 900px) {
-       .cols { grid-template-columns: 1fr 1fr; }
-     }
-   </style>
- </head>
- <body>
-   <div class=\"wrap\">
-     <h1>ACRE Live Refactor Arena</h1>
-     <p class=\"sub\">Paste old code, run the agent, and compare before and after with a full diff and step-by-step rewards.</p>
-
-     <div class=\"panel\">
-       <div class=\"controls\">
-         <button onclick=\"loadExample(1)\">Load Example 1</button>
-         <button onclick=\"loadExample(2)\">Load Example 2</button>
-         <select id=\"task\">
-           <option value=\"\">Auto strategy</option>
-           <option value=\"rename_variables\">rename_variables</option>
-           <option value=\"remove_dead_code\">remove_dead_code</option>
-           <option value=\"full_refactor\">full_refactor</option>
-         </select>
-         <button class=\"primary\" onclick=\"runOptimize()\">Run Optimization</button>
-       </div>
-       <div class=\"controls\" style=\"margin-bottom: 10px;\">
-         <select id=\"mode\">
-           <option value=\"rl_then_llm\">RL First -> LLM Fallback</option>
-           <option value=\"heuristic\">Heuristic Agent (no API key)</option>
-           <option value=\"llm\">LLM Agent (OpenAI-compatible API)</option>
-         </select>
-         <input id=\"rlModelPath\" placeholder=\"RL model path\" value=\"acre_agent.zip\" style=\"border:1px solid var(--stroke);border-radius:10px;padding:10px 12px;background:rgba(1,13,24,0.82);color:#dcf4ff;\" />
-         <input id=\"baseUrl\" placeholder=\"API base URL (optional)\" value=\"https://api.openai.com/v1\" style=\"border:1px solid var(--stroke);border-radius:10px;padding:10px 12px;background:rgba(1,13,24,0.82);color:#dcf4ff;\" />
-         <input id=\"modelName\" placeholder=\"Model name (optional)\" value=\"gpt-4o-mini\" style=\"border:1px solid var(--stroke);border-radius:10px;padding:10px 12px;background:rgba(1,13,24,0.82);color:#dcf4ff;\" />
-         <input id=\"apiToken\" type=\"password\" placeholder=\"Paste API token here for LLM mode\" style=\"border:1px solid var(--stroke);border-radius:10px;padding:10px 12px;background:rgba(1,13,24,0.82);color:#dcf4ff;\" />
-       </div>
-       <div class=\"controls\" style=\"margin-bottom: 10px;\">
-         <label style=\"display:flex;align-items:center;gap:8px;padding:8px 10px;border:1px solid var(--stroke);border-radius:10px;\">
-           <input id=\"autoSuggest\" type=\"checkbox\" />
-           Auto suggest after typing pause
-         </label>
-       </div>
-       <textarea id=\"input\" spellcheck=\"false\" placeholder=\"Paste your Python code here...\"></textarea>
-       <p class=\"meta\" id=\"status\">Status: ready</p>
-     </div>
-
-     <div class=\"cols\" style=\"margin-top: 14px\">
-       <div class=\"panel\">
-         <h3>Original Code</h3>
-         <pre id=\"original\"></pre>
-       </div>
-       <div class=\"panel\">
-         <h3>Optimized Code</h3>
-         <pre id=\"optimized\"></pre>
-       </div>
-     </div>
-
-     <div class=\"panel\" style=\"margin-top: 14px\">
-       <h3>Diff</h3>
-       <pre id=\"diff\"></pre>
-     </div>
-
-     <div class=\"panel\" style=\"margin-top: 14px\">
-       <h3>Step Logs</h3>
-       <pre id=\"steps\"></pre>
-     </div>
-   </div>
-
-   <script>
-     const EX1 = `def compute(x, y, tmp):\n tmp = x + y\n x = tmp * 2\n result = x\n return result\n`;
-     const EX2 = `def add(p, q):\n return p + q\n\ndef compute(x, data, tmp):\n result = []\n for item in data:\n result.append(item * 2)\n if False:\n y = 999\n if True:\n val = add(x, tmp)\n unused = 0\n flag = not not True\n return val\n print(\"dead\")\n`;
-     let autoTimer = null;
-
-     function loadExample(i) {
-       document.getElementById('input').value = i === 1 ? EX1 : EX2;
-       document.getElementById('status').textContent = `Status: loaded example ${i}`;
-     }
-
-     async function runOptimize() {
-       const code = document.getElementById('input').value;
-       const task = document.getElementById('task').value || null;
-       const mode = document.getElementById('mode').value;
-       const useRl = mode === 'rl_then_llm';
-       const useLlm = mode === 'llm' || mode === 'rl_then_llm';
-       const fallbackToLlm = mode === 'rl_then_llm';
-       const rlModelPath = document.getElementById('rlModelPath').value || null;
-       const apiToken = document.getElementById('apiToken').value || null;
-       const apiBaseUrl = document.getElementById('baseUrl').value || null;
-       const modelName = document.getElementById('modelName').value || null;
-       if (!code.trim()) {
-         document.getElementById('status').innerHTML = 'Status: <span class=\"badge warn\">please paste code first</span>';
-         return;
-       }
-       if (mode === 'llm' && (!apiToken || !apiToken.trim())) {
-         document.getElementById('status').innerHTML = 'Status: <span class=\"badge warn\">paste API token for LLM mode</span>';
-         return;
-       }
-
-       document.getElementById('status').textContent = 'Status: running optimization...';
-       try {
-         const res = await fetch('/optimize', {
-           method: 'POST',
-           headers: {'Content-Type': 'application/json'},
-           body: JSON.stringify({
-             code,
-             task_id: task,
-             max_steps: 5,
-             use_rl: useRl,
-             use_llm: useLlm,
-             fallback_to_llm: fallbackToLlm,
-             rl_model_path: rlModelPath,
-             api_base_url: apiBaseUrl,
-             model_name: modelName,
-             api_token: apiToken,
-           })
-         });
-         const data = await res.json();
-         if (!res.ok) {
-           throw new Error(data.detail || 'request failed');
-         }
-
-         document.getElementById('original').textContent = data.original_code;
-         document.getElementById('optimized').textContent = data.optimized_code;
-         document.getElementById('diff').textContent = data.diff || '(no diff)';
-         document.getElementById('steps').textContent = JSON.stringify(data.steps, null, 2);
-
-         const scoreText = data.task_score === null ? 'n/a' : data.task_score;
-         document.getElementById('status').innerHTML = `Status: <span class=\"badge\">done</span> cumulative_reward=${data.cumulative_reward.toFixed(2)} task_score=${scoreText}`;
-       } catch (err) {
-         document.getElementById('status').innerHTML = `Status: <span class=\"badge warn\">error</span> ${err.message}`;
-       }
-     }
-
-     loadExample(1);
-     document.getElementById('input').addEventListener('input', () => {
-       if (!document.getElementById('autoSuggest').checked) {
-         return;
-       }
-       if (autoTimer) {
-         clearTimeout(autoTimer);
-       }
-       autoTimer = setTimeout(() => {
-         runOptimize();
-       }, 1200);
-     });
-   </script>
- </body>
- </html>"""
-
-
- # ---------------------------------------------------------------------------
- # Routes
- # ---------------------------------------------------------------------------
-
- @app.get("/", response_model=HealthResponse)
- def health() -> HealthResponse:
-     """Health check — OpenEnv pings this URL to verify the Space is live."""
-     return HealthResponse(status="ok", env="ACRE", version="1.0.0")
-
-
- @app.get("/health", response_model=CompatibilityHealthResponse)
- def health_compat() -> CompatibilityHealthResponse:
-     """Compatibility health route used by some OpenEnv reference environments."""
-     return CompatibilityHealthResponse(status="healthy", service="acre-env")
-
-
- @app.get("/demo", response_class=HTMLResponse)
- def demo_ui() -> HTMLResponse:
-     """Simple UI to compare original and optimized code side-by-side."""
-     return HTMLResponse(content=_demo_html())
-
-
- @app.post("/reset", response_model=ResetResponse)
- def reset(req: ResetRequest = ResetRequest()) -> ResetResponse:
-     """Reset the environment. Optionally load a task's initial code."""
-     env = get_env()
-     try:
-         obs = env.reset(seed=req.seed, task_id=req.task_id, code=req.code)
-     except ValueError as exc:
-         raise HTTPException(status_code=404, detail=str(exc)) from exc
-     return ResetResponse(
-         observation=obs,
-         observation_vector=obs.to_vector(),
-         info=env.last_reset_info,
-         task_id=req.task_id,
-         state=_state_response(),
-     )
-
-
- @app.post("/step", response_model=StepResponse)
- def step(req: StepRequest) -> StepResponse:
-     """Take one refactoring step."""
-     env = get_env()
-     if not (0 <= req.action <= 4):
-         raise HTTPException(status_code=400, detail="action must be 0–4")
-
-     obs, reward, done, info = env.step(req.action)
-     action_name = str(info.get("action_name", env.action_meanings.get(req.action, "unknown")))
-
-     return StepResponse(
-         action=ActionModel(action=req.action, action_name=action_name),
-         observation=obs,
-         observation_vector=obs.to_vector(),
-         reward=reward,
-         done=done,
-         terminated=done,
-         truncated=False,
-         info=info,
-         state=_state_response(),
-     )
-
-
- @app.get("/state", response_model=StateResponse)
- def state() -> StateResponse:
-     """Return full current environment state (OpenEnv spec requirement)."""
-     return _state_response()
-
-
- @app.get("/tasks", response_model=TasksResponse)
- def list_tasks() -> TasksResponse:
-     """Enumerate all tasks (easy → medium → hard)."""
-     return TasksResponse(tasks=[TaskInfo.model_validate(t) for t in registry.list_tasks()])
-
-
- @app.post("/tasks/{task_id}/grade", response_model=GradeResponse)
- def grade(task_id: str, req: GradeRequest) -> GradeResponse:
-     """Grade submitted code against a task's grader (returns score 0.0–1.0)."""
-     task = registry.get_task(task_id)
-     if task is None:
-         raise HTTPException(status_code=404, detail=f"Task '{task_id}' not found")
-     score = task.grade(req.code)
-     return GradeResponse(
-         task_id=task_id,
-         score=round(score, 4),
-         passed=score >= 0.8,
-     )
-
-
- @app.post("/optimize", response_model=OptimizeResponse)
- def optimize(req: OptimizeRequest) -> OptimizeResponse:
-     """Run a full optimization episode and return code comparison artifacts."""
-     code = req.code.strip("\n")
-     if not code.strip():
-         raise HTTPException(status_code=400, detail="code must be non-empty")
-
-     env = get_env()
-     try:
-         env.reset(task_id=req.task_id, code=code)
-     except ValueError as exc:
-         raise HTTPException(status_code=404, detail=str(exc)) from exc
-
-     steps: list[OptimizationStep] = []
-     cumulative_reward = 0.0
-
-     for step_idx in range(1, req.max_steps + 1):
-         state_now = env.state()
-         current_code = state_now.current_code
-         obs_list = [float(x) for x in state_now.observation_vector]
-
-         action: int
-         reason: str
-         source: str
-
-         if req.use_rl:
-             rl_action, rl_reason, rl_source = _choose_action_rl(
-                 observation=obs_list,
-                 model_path=req.rl_model_path or DEFAULT_RL_MODEL_PATH,
-             )
-             if rl_action is not None:
-                 action, reason, source = rl_action, rl_reason, rl_source
-             elif req.fallback_to_llm and req.use_llm:
-                 action, reason, source = _choose_action_llm(
-                     code=current_code,
-                     task_id=req.task_id,
-                     step_index=step_idx,
-                     max_steps=req.max_steps,
-                     api_base_url=req.api_base_url or DEFAULT_API_BASE_URL,
-                     model_name=req.model_name or DEFAULT_MODEL_NAME,
-                     api_token=req.api_token or "",
-                 )
-                 reason = f"{rl_reason}; {reason}"
-             else:
-                 action = _choose_action_heuristic(current_code, req.task_id)
-                 reason = f"{rl_reason}; heuristic fallback"
-                 source = "heuristic"
-         elif req.use_llm:
-             action, reason, source = _choose_action_llm(
-                 code=current_code,
-                 task_id=req.task_id,
-                 step_index=step_idx,
-                 max_steps=req.max_steps,
-                 api_base_url=req.api_base_url or DEFAULT_API_BASE_URL,
-                 model_name=req.model_name or DEFAULT_MODEL_NAME,
-                 api_token=req.api_token or "",
-             )
-         else:
-             action = _choose_action_heuristic(current_code, req.task_id)
-             reason = "heuristic policy"
-             source = "heuristic"
-
-         _, reward, done, info = env.step(action)
-         state_now = env.state()
-
-         cumulative_reward += float(reward.raw)
-         steps.append(
-             OptimizationStep(
-                 step=step_idx,
-                 action=action,
-                 action_name=info.get("action_name", "unknown"),
-                 reason=reason,
-                 source=source,
-                 reward=float(reward.raw),
-                 normalized_reward=float(reward.normalized),
-                 changed=bool(info.get("changed", False)),
-                 complexity=float(state_now.complexity),
-             )
-         )
-
-         if done:
-             break
-
-     final_code = str(env.state().current_code)
-     diff_lines = difflib.unified_diff(
-         code.splitlines(),
-         final_code.splitlines(),
-         fromfile="original.py",
-         tofile="optimized.py",
-         lineterm="",
-     )
-     diff_text = "\n".join(diff_lines)
-
-     task_score: Optional[float] = None
-     if req.task_id:
-         task = registry.get_task(req.task_id)
-         if task is None:
-             raise HTTPException(status_code=404, detail=f"Task '{req.task_id}' not found")
-         task_score = round(task.grade(final_code), 4)
-
-     return OptimizeResponse(
-         original_code=code,
-         optimized_code=final_code,
-         diff=diff_text,
-         steps=steps,
-         cumulative_reward=round(cumulative_reward, 4),
-         task_id=req.task_id,
-         task_score=task_score,
-     )
-
-
- # ---------------------------------------------------------------------------
- # Entry point
- # ---------------------------------------------------------------------------
-
- if __name__ == "__main__":
-     port = int(os.getenv("PORT", 7860))
-     uvicorn.run(app, host="0.0.0.0", port=port)
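The server above picks each action by trying selectors in order (RL policy, then LLM, then heuristic) and parses the LLM reply by pulling the first JSON object out of free-form text. The following is a simplified sketch of those two ideas in isolation. `extract_action` and `choose_with_fallback` are hypothetical helper names, not functions from `server.py`, and the sketch leaves out the model call and the per-step reason strings.

```python
import json
import re
from typing import Callable, List, Optional, Tuple

# A selector returns an action index, or None when it cannot decide.
Selector = Callable[[], Optional[int]]


def extract_action(raw: str, n_actions: int = 5) -> Optional[int]:
    """Pull {"action": k} out of free-form model text; None if unusable."""
    match = re.search(r"\{.*\}", raw, flags=re.DOTALL)
    blob = match.group(0) if match else raw
    try:
        action = int(json.loads(blob).get("action", -1))
    except (ValueError, TypeError, AttributeError):
        # Covers invalid JSON, non-numeric values, and non-dict payloads.
        return None
    return action if 0 <= action < n_actions else None


def choose_with_fallback(
    selectors: List[Tuple[str, Selector]], default: int
) -> Tuple[int, str]:
    """Try each named selector in order; the first non-None action wins."""
    for name, select in selectors:
        action = select()
        if action is not None:
            return action, name
    # Nothing decided: fall back to the caller-supplied default action.
    return default, "heuristic"
```

Returning `None` rather than raising keeps the chain linear: each stage only has to report whether it produced a usable action, and the source label records which stage actually decided.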
ACRE_FINAL/validate.py DELETED
@@ -1,281 +0,0 @@
- """
- ACRE pre-submission validator.
-
- Checks the repository against the submission checklist and, when a server URL is
- available, probes the HTTP API as well.
-
- Run:
-     python validate.py --url http://localhost:7860
- """
- from __future__ import annotations
-
- import argparse
- import ast
- import re
- import sys
- from typing import Any, Tuple
-
- try:
-     import requests
- except ImportError:
-     print("[ERROR] requests is required. Run: pip install requests")
-     sys.exit(1)
-
- PASS = "\033[92m[PASS]\033[0m"
- FAIL = "\033[91m[FAIL]\033[0m"
-
-
- def check(label: str, ok: bool, detail: str = "") -> bool:
-     status = PASS if ok else FAIL
-     message = f" {status} {label}"
-     if detail:
-         message += f" - {detail}"
-     print(message)
-     return ok
-
-
- def get(url: str, path: str, timeout: int = 15) -> Tuple[bool, Any]:
-     try:
-         response = requests.get(f"{url}{path}", timeout=timeout)
-         response.raise_for_status()
-         return True, response.json()
-     except Exception as exc:
-         return False, str(exc)
-
-
- def post(url: str, path: str, payload: dict, timeout: int = 15) -> Tuple[bool, Any]:
-     try:
-         response = requests.post(f"{url}{path}", json=payload, timeout=timeout)
-         response.raise_for_status()
-         return True, response.json()
-     except Exception as exc:
-         return False, str(exc)
-
-
- def read_text(path: str) -> str:
-     with open(path, encoding="utf-8") as handle:
-         return handle.read()
-
-
- def run_validation(base_url: str) -> int:
-     failures = 0
-
-     print("\n" + "=" * 60)
-     print(" ACRE Pre-Submission Validator")
-     print("=" * 60)
-     print(f" Target: {base_url}\n")
-
-     print("1. Static repository checks")
-     try:
-         interface_src = read_text("openenv_interface.py")
-         tree = ast.parse(interface_src)
-         classes = {node.name: node for node in tree.body if isinstance(node, ast.ClassDef)}
-         env_cls = classes.get("OpenEnvRefactorEnv")
-         failures += 0 if check("openenv_interface.py exists", True) else 1
-         failures += 0 if check("OpenEnvRefactorEnv is defined", env_cls is not None) else 1
-         if env_cls is not None:
-             methods = {node.name for node in env_cls.body if isinstance(node, ast.FunctionDef)}
-             for method_name in ["reset", "step", "state"]:
-                 failures += 0 if check(
-                     f"OpenEnvRefactorEnv implements {method_name}()",
-                     method_name in methods,
-                 ) else 1
-     except FileNotFoundError:
-         failures += 1
-         check("openenv_interface.py exists", False, "file not found")
-
-     try:
-         models_src = read_text("models.py")
-         for name in ["ObservationModel", "ActionModel", "RewardModel"]:
-             failures += 0 if check(
-                 f"{name} is defined in models.py",
-                 f"class {name}" in models_src,
-             ) else 1
-     except FileNotFoundError:
-         failures += 1
-         check("models.py exists", False, "file not found")
-
-     print("\n2. Health check (GET /)")
-     ok, data = get(base_url, "/")
-     failures += 0 if check("GET / returns HTTP 200", ok) else 1
-     if ok:
-         failures += 0 if check(
-             "Response has status field",
-             isinstance(data, dict) and "status" in data,
-             str(data),
-         ) else 1
-
-     print("\n3. Tasks (GET /tasks)")
-     ok, data = get(base_url, "/tasks")
-     failures += 0 if check("GET /tasks returns 200", ok) else 1
-     if ok:
-         tasks = data.get("tasks", []) if isinstance(data, dict) else []
-         failures += 0 if check("At least 3 tasks defined", len(tasks) >= 3, f"found {len(tasks)}") else 1
-         difficulties = [t.get("difficulty", "") for t in tasks]
-         for diff in ["easy", "medium", "hard"]:
-             failures += 0 if check(f"Task with difficulty '{diff}' exists", diff in difficulties) else 1
-         for task in tasks:
-             failures += 0 if check(
-                 f"Task '{task.get('id')}' has initial_code",
-                 bool(task.get("initial_code")),
-             ) else 1
-
-     print("\n4. Reset (POST /reset)")
-     ok, data = post(base_url, "/reset", {})
-     failures += 0 if check("POST /reset returns 200", ok) else 1
-     if ok:
-         observation = data.get("observation", {})
-         failures += 0 if check("Response has observation field", isinstance(observation, dict)) else 1
-         failures += 0 if check(
-             "Observation is typed with 4 fields",
-             {"code_length", "complexity_score", "runtime_s", "error_flag"}.issubset(observation),
-             str(observation),
-         ) else 1
-
-     ok, _ = post(base_url, "/reset", {"task_id": "rename_variables"})
-     failures += 0 if check("POST /reset with task_id works", ok) else 1
-
-     print("\n5. State (GET /state)")
-     ok, data = get(base_url, "/state")
-     failures += 0 if check("GET /state returns 200", ok) else 1
-     if ok:
-         required_keys = [
-             "current_code",
-             "episode_steps",
-             "max_steps",
-             "complexity",
-             "observation",
-             "observation_vector",
-             "action_meanings",
-         ]
-         for key in required_keys:
-             failures += 0 if check(f"State has '{key}' field", key in data) else 1
-
-     print("\n6. Step (POST /step)")
-     post(base_url, "/reset", {"task_id": "rename_variables"})
-     for action in range(5):
-         ok, data = post(base_url, "/step", {"action": action})
-         failures += 0 if check(
-             f"Action {action} executes without error",
-             ok and isinstance(data, dict) and "reward" in data and "done" in data,
-         ) else 1
-         if ok:
-             reward_payload = data.get("reward", {})
-             norm = reward_payload.get("normalized", -1)
-             failures += 0 if check(
-                 f"Action {action} returns typed reward payload",
-                 {"raw", "normalized", "components"}.issubset(reward_payload),
-                 str(reward_payload),
-             ) else 1
-             failures += 0 if check(
-                 f"Action {action} normalized_reward in [0,1]",
-                 isinstance(norm, (int, float)) and 0.0 <= float(norm) <= 1.0,
-                 f"got {norm}",
-             ) else 1
-             if data.get("done"):
-                 break
-
-     ok, data = post(base_url, "/step", {"action": 99})
-     check("Invalid action returns error (not crash)", not ok or "detail" in str(data), "(expected 4xx)")
-
-     print("\n7. Task graders (POST /tasks/{id}/grade)")
-     for task_id in ["rename_variables", "remove_dead_code", "full_refactor"]:
-         ok, data = post(base_url, f"/tasks/{task_id}/grade", {"code": "def f(): pass"})
-         failures += 0 if check(f"Grade endpoint for '{task_id}' works", ok) else 1
185
- if ok:
186
- score = data.get("score", -1)
187
- failures += 0 if check(
188
- f"Score for '{task_id}' in [0.0, 1.0]",
189
- isinstance(score, (int, float)) and 0.0 <= float(score) <= 1.0,
190
- f"got {score}",
191
- ) else 1
192
-
193
- print("\n8. openenv.yaml")
194
- try:
195
- openenv_yaml = read_text("openenv.yaml")
196
- failures += 0 if check("openenv.yaml exists", True) else 1
197
- for field in ["tasks:", "action_space:", "observation_space:", "reward:", "entrypoint:", "validation:"]:
198
- failures += 0 if check(f"openenv.yaml has '{field}' section", field in openenv_yaml) else 1
199
- except FileNotFoundError:
200
- failures += 1
201
- check("openenv.yaml exists", False, "file not found")
202
-
203
- print("\n9. inference.py")
204
- try:
205
- inference_src = read_text("inference.py")
206
- failures += 0 if check("inference.py exists", True) else 1
207
- for marker in ['"event": "START"', '"event": "STEP"', '"event": "END"']:
208
- failures += 0 if check(f"inference.py emits {marker}", marker in inference_src) else 1
209
- failures += 0 if check(
210
- "Uses OpenAI client",
211
- "from openai import OpenAI" in inference_src,
212
- ) else 1
213
- for var in ["API_BASE_URL", "MODEL_NAME", "HF_TOKEN", "ENV_URL", "LOCAL_IMAGE_NAME"]:
214
- failures += 0 if check(f"inference.py reads {var} from env", var in inference_src) else 1
215
- failures += 0 if check(
216
- "API_BASE_URL has a default",
217
- 'os.getenv("API_BASE_URL", "https://api.openai.com/v1")' in inference_src,
218
- ) else 1
219
- failures += 0 if check(
220
- "MODEL_NAME has a default",
221
- 'os.getenv("MODEL_NAME", "gpt-4o-mini")' in inference_src,
222
- ) else 1
223
- failures += 0 if check(
224
- "HF_TOKEN has no default",
225
- re.search(r'HF_TOKEN\s*:\s*.*os\.getenv\("HF_TOKEN"\)', inference_src) is not None,
226
- ) else 1
227
- except FileNotFoundError:
228
- failures += 1
229
- check("inference.py exists", False, "file not found")
230
-
231
- print("\n10. Dockerfile")
232
- try:
233
- dockerfile = read_text("Dockerfile")
234
- failures += 0 if check("Dockerfile exists", True) else 1
235
- failures += 0 if check("Exposes port 7860", "7860" in dockerfile) else 1
236
- failures += 0 if check("Has CMD/ENTRYPOINT", "CMD" in dockerfile or "ENTRYPOINT" in dockerfile) else 1
237
- failures += 0 if check("Does not set a default HF_TOKEN", "ENV HF_TOKEN" not in dockerfile) else 1
238
- except FileNotFoundError:
239
- failures += 1
240
- check("Dockerfile exists", False, "file not found")
241
-
242
- print("\n11. README / Hugging Face metadata")
243
- try:
244
- readme = read_text("README.md")
245
- failures += 0 if check("README has docker SDK front matter", "sdk: docker" in readme) else 1
246
- failures += 0 if check("README includes openenv tag", "openenv" in readme) else 1
247
- for section in [
248
- "Environment Overview and Motivation",
249
- "Definitions of Action and Observation Spaces",
250
- "Task Descriptions with Expected Difficulty Levels",
251
- "Setup and Usage Instructions",
252
- "Baseline Performance Scores",
253
- ]:
254
- failures += 0 if check(f"README includes '{section}'", section in readme) else 1
255
- except FileNotFoundError:
256
- failures += 1
257
- check("README.md exists", False, "file not found")
258
-
259
- print("\n" + "=" * 60)
260
- if failures == 0:
261
- print(f" {PASS} All checks passed. Repository is submission-ready.")
262
- else:
263
- print(f" {FAIL} {failures} check(s) failed. Fix before submitting.")
264
- print("=" * 60 + "\n")
265
-
266
- return failures
267
-
268
-
269
- def main() -> None:
270
- parser = argparse.ArgumentParser(description="ACRE pre-submission validator")
271
- parser.add_argument(
272
- "--url",
273
- default="http://localhost:7860",
274
- help="Base URL of the running ACRE server",
275
- )
276
- args = parser.parse_args()
277
- sys.exit(run_validation(args.url))
278
-
279
-
280
- if __name__ == "__main__":
281
- main()
 
README.md CHANGED
@@ -167,9 +167,9 @@ The deterministic fallback policy used by `inference.py` produces the following
167
 
168
  | Task | Score |
169
  |---|---|
170
- | `rename_variables` | 1.0 |
171
- | `remove_dead_code` | 1.0 |
172
- | `full_refactor` | 1.0 |
173
- | Average | 1.0 |
174
 
175
  These scores come from the built-in heuristic policy with `HF_TOKEN` unset, which keeps the baseline reproducible across runs.
 
167
 
168
  | Task | Score |
169
  |---|---|
170
+ | `rename_variables` | 1.0000 |
171
+ | `remove_dead_code` | 0.2500 |
172
+ | `full_refactor` | 0.7143 |
173
+ | Average | 0.6548 |
174
 
175
  These scores come from the built-in heuristic policy with `HF_TOKEN` unset, which keeps the baseline reproducible across runs.
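As a quick arithmetic check on the updated table, the sketch below recomputes the average from the three per-task scores. The variable names are illustrative. Reading `0.2500` as 1 of 4 medium checks and `0.7143` as 5 of 7 hard checks is an assumption based on the grader changes in this commit, not something the README states.

```python
# Per-task baseline scores copied from the updated README table.
scores = {
    "rename_variables": 1.0000,
    "remove_dead_code": 0.2500,  # assumed: 1 of 4 grader checks passed
    "full_refactor": 0.7143,     # assumed: 5 of 7 grader checks passed
}

# Unweighted mean across the three tasks.
average = sum(scores.values()) / len(scores)
print(f"Average: {average:.4f}")
```

The result agrees with the `Average` row of the table when rounded to four decimal places.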
acre/tasks/task_registry.py CHANGED
@@ -5,7 +5,7 @@ from __future__ import annotations
5
 
6
  import ast
7
  from dataclasses import dataclass
8
- from typing import Callable, Dict, List, Optional
9
 
10
 
11
  @dataclass
@@ -14,9 +14,13 @@ class Task:
14
  name: str
15
  description: str
16
  difficulty: str
17
- initial_code: str
18
  _grade_fn: Callable[[str], float]
19
 
20
  def grade(self, code: str) -> float:
21
  """Return a score in [0.0, 1.0]."""
22
  try:
@@ -25,21 +29,90 @@ class Task:
25
  return 0.0
26
 
27
 
28
  # ---------------------------------------------------------------------------
29
  # Task 1 — Easy: Rename generic variables
30
  # ---------------------------------------------------------------------------
31
- _EASY_CODE = """\
32
  def compute(x, y, tmp):
33
  tmp = x + y
34
  x = tmp * 2
35
  result = x
36
  return result
37
- """
38
 
39
 
40
  def _grade_easy(code: str) -> float:
41
- """Score = fraction of generic names (x, tmp) removed from all scopes."""
42
- generic = {"x", "tmp"}
43
  try:
44
  tree = ast.parse(code)
45
  except SyntaxError:
@@ -66,7 +139,8 @@ def _grade_easy(code: str) -> float:
66
  # ---------------------------------------------------------------------------
67
  # Task 2 — Medium: Remove dead code
68
  # ---------------------------------------------------------------------------
69
- _MEDIUM_CODE = """\
70
  def process(data):
71
  result = []
72
  for item in data:
@@ -76,31 +150,74 @@ def process(data):
76
  unused_var = 42
77
  return result
78
  print("unreachable")
79
- """
80
 
81
 
82
  def _grade_medium(code: str) -> float:
83
- """Score = fraction of dead-code patterns eliminated (3 checks, ~0.33 each)."""
84
  try:
85
  tree = ast.parse(code)
86
  except SyntaxError:
87
  return 0.0
88
 
89
- source = ast.unparse(tree)
90
  score = 0.0
91
 
92
- # Check 1: if-False block removed
93
- if "if False" not in source:
94
- score += 1 / 3
95
 
96
- # Check 2: unused_var assignment removed
97
- if "unused_var" not in source:
98
- score += 1 / 3
99
 
100
  # Check 3: list comprehension used (loop simplified)
101
  has_listcomp = any(isinstance(n, ast.ListComp) for n in ast.walk(tree))
102
  if has_listcomp:
103
- score += 1 / 3
104
 
105
  return score
106
 
@@ -108,7 +225,8 @@ def _grade_medium(code: str) -> float:
108
  # ---------------------------------------------------------------------------
109
  # Task 3 — Hard: Full refactor
110
  # ---------------------------------------------------------------------------
111
- _HARD_CODE = """\
112
  def add(p, q):
113
  return p + q
114
 
@@ -124,34 +242,89 @@ def compute(x, data, tmp):
124
  flag = not not True
125
  return val
126
  print("dead")
127
- """
128
 
129
 
130
  def _grade_hard(code: str) -> float:
131
- """Score = fraction of 5 quality checks passed."""
132
  try:
133
  tree = ast.parse(code)
134
  except SyntaxError:
135
  return 0.0
136
 
137
- source = ast.unparse(tree)
138
  checks = 0
139
 
140
- # 1. No generic variable names x/tmp in function signature or body
141
  has_generic = False
142
 
143
  class _GenCheck(ast.NodeVisitor):
144
  def visit_arg(self, node: ast.arg) -> None:
145
  nonlocal has_generic
146
- if node.arg in {"x", "tmp"}:
147
  has_generic = True
148
 
149
  _GenCheck().visit(tree)
150
  if not has_generic:
151
  checks += 1
152
 
153
- # 2. No if False block
154
- if "if False" not in source:
155
  checks += 1
156
 
157
  # 3. if True removed (body inlined)
@@ -162,13 +335,21 @@ def _grade_hard(code: str) -> float:
162
  if any(isinstance(n, ast.ListComp) for n in ast.walk(tree)):
163
  checks += 1
164
 
165
- # 5. add() call inlined (no call to 'add')
166
  calls = [n for n in ast.walk(tree) if isinstance(n, ast.Call)]
167
  fn_names = {c.func.id for c in calls if isinstance(c.func, ast.Name)}
168
- if "add" not in fn_names:
169
  checks += 1
170
 
171
- return checks / 5
172
 
173
 
174
  # ---------------------------------------------------------------------------
@@ -186,7 +367,7 @@ class TaskRegistry:
186
  name="Rename Variables (Easy)",
187
  description="Rename generic variable names (x, tmp) to descriptive ones",
188
  difficulty="easy",
189
- initial_code=_EASY_CODE,
190
  _grade_fn=_grade_easy,
191
  )
192
  self._tasks["remove_dead_code"] = Task(
@@ -194,7 +375,7 @@ class TaskRegistry:
194
  name="Remove Dead Code (Medium)",
195
  description="Remove unreachable code, if False blocks, and unused variables",
196
  difficulty="medium",
197
- initial_code=_MEDIUM_CODE,
198
  _grade_fn=_grade_medium,
199
  )
200
  self._tasks["full_refactor"] = Task(
@@ -202,7 +383,7 @@ class TaskRegistry:
202
  name="Full Refactor (Hard)",
203
  description="Apply all transformations: rename, dead code, loops, conditions, inlining",
204
  difficulty="hard",
205
- initial_code=_HARD_CODE,
206
  _grade_fn=_grade_hard,
207
  )
208
 
 
5
 
6
  import ast
7
  from dataclasses import dataclass
8
+ from typing import Callable, Dict, List, Optional, Sequence
9
 
10
 
11
  @dataclass
 
14
  name: str
15
  description: str
16
  difficulty: str
17
+ samples: List[str]
18
  _grade_fn: Callable[[str], float]
19
 
20
+ @property
21
+ def initial_code(self) -> str:
22
+ return str(self.samples[0]) if self.samples else ""
23
+
24
  def grade(self, code: str) -> float:
25
  """Return a score in [0.0, 1.0]."""
26
  try:
 
29
  return 0.0
30
 
31
 
32
+ def _safe_unparse(tree: ast.AST) -> str:
33
+ try:
34
+ return ast.unparse(tree)
35
+ except Exception:
36
+ return ""
37
+
38
+
39
+ def _has_unreachable_after_terminator(stmts: Sequence[ast.stmt]) -> bool:
40
+ unreachable = False
41
+ for s in stmts:
42
+ if unreachable:
43
+ # skip bare string-literal statements rather than counting them as unreachable
44
+ if isinstance(s, ast.Expr) and isinstance(s.value, ast.Constant) and isinstance(s.value.value, str):
45
+ continue
46
+ return True
47
+ if isinstance(s, (ast.Return, ast.Raise)):
48
+ unreachable = True
49
+ return False
50
+
51
+
52
+ def _tree_has_unreachable(tree: ast.AST) -> bool:
53
+ class _Scan(ast.NodeVisitor):
54
+ def __init__(self) -> None:
55
+ self.bad = False
56
+
57
+ def visit_FunctionDef(self, node: ast.FunctionDef) -> None: # noqa: N802
58
+ if _has_unreachable_after_terminator(node.body):
59
+ self.bad = True
60
+ self.generic_visit(node)
61
+
62
+ def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None: # noqa: N802
63
+ if _has_unreachable_after_terminator(node.body):
64
+ self.bad = True
65
+ self.generic_visit(node)
66
+
67
+ s = _Scan()
68
+ s.visit(tree)
69
+ return bool(s.bad)
70
+
71
+
72
  # ---------------------------------------------------------------------------
73
  # Task 1 — Easy: Rename generic variables
74
  # ---------------------------------------------------------------------------
75
+ _EASY_SAMPLES: List[str] = [
76
+ """\
77
  def compute(x, y, tmp):
78
  tmp = x + y
79
  x = tmp * 2
80
  result = x
81
  return result
82
+ """,
83
+ """\
84
+ def normalize(tmp, x):
85
+ for i in range(3):
86
+ tmp = tmp + i
87
+ return tmp * x
88
+ """,
89
+ """\
90
+ def score(items):
91
+ tmp = 0
92
+ for i in items:
93
+ tmp += i
94
+ x = tmp
95
+ return x
96
+ """,
97
+ """\
98
+ def transform(x):
99
+ tmp = x
100
+ if tmp > 10:
101
+ tmp = tmp - 1
102
+ return tmp
103
+ """,
104
+ """\
105
+ def merge(a, b):
106
+ x = a
107
+ tmp = b
108
+ return x + tmp
109
+ """,
110
+ ]
111
 
112
 
113
  def _grade_easy(code: str) -> float:
114
+ """Score = fraction of generic names removed from all scopes."""
115
+ generic = {"x", "tmp", "i"}
116
  try:
117
  tree = ast.parse(code)
118
  except SyntaxError:
 
139
  # ---------------------------------------------------------------------------
140
  # Task 2 — Medium: Remove dead code
141
  # ---------------------------------------------------------------------------
142
+ _MEDIUM_SAMPLES: List[str] = [
143
+ """\
144
  def process(data):
145
  result = []
146
  for item in data:
 
150
  unused_var = 42
151
  return result
152
  print("unreachable")
153
+ """,
154
+ """\
155
+ def build(values):
156
+ out = []
157
+ for v in values:
158
+ out.append(v + 1)
159
+ while False:
160
+ out.append(999)
161
+ dead = 0
162
+ return out
163
+ dead += 1
164
+ """,
165
+ """\
166
+ def route(flag):
167
+ if False:
168
+ return 1
169
+ if True:
170
+ x = 2
171
+ y = x
172
+ return y
173
+ """,
174
+ """\
175
+ def clean(xs):
176
+ res = []
177
+ for x in xs:
178
+ res.append(x * 2)
179
+ unused = "remove me"
180
+ if False:
181
+ unused2 = 123
182
+ return res
183
+ """,
184
+ """\
185
+ def calc(n):
186
+ total = 0
187
+ for i in range(n):
188
+ total += i
189
+ return total
190
+ print("dead")
191
+ """,
192
+ ]
193
 
194
 
195
  def _grade_medium(code: str) -> float:
196
+ """Score = fraction of dead-code patterns eliminated (4 checks, 0.25 each)."""
197
  try:
198
  tree = ast.parse(code)
199
  except SyntaxError:
200
  return 0.0
201
 
202
+ source = _safe_unparse(tree)
203
  score = 0.0
204
 
205
+ # Check 1: if/while-False removed
206
+ if ("if False" not in source) and ("while False" not in source):
207
+ score += 0.25
208
 
209
+ # Check 2: no unreachable statements after return/raise
210
+ if not _tree_has_unreachable(tree):
211
+ score += 0.25
212
 
213
  # Check 3: list comprehension used (loop simplified)
214
  has_listcomp = any(isinstance(n, ast.ListComp) for n in ast.walk(tree))
215
  if has_listcomp:
216
+ score += 0.25
217
+
218
+ # Check 4: obvious dead/unused sentinel names removed
219
+ if all(name not in source for name in ["unused_var", "unused", "dead", "unused2"]):
220
+ score += 0.25
221
 
222
  return score
223
 
 
225
  # ---------------------------------------------------------------------------
226
  # Task 3 — Hard: Full refactor
227
  # ---------------------------------------------------------------------------
228
+ _HARD_SAMPLES: List[str] = [
229
+ """\
230
  def add(p, q):
231
  return p + q
232
 
 
242
  flag = not not True
243
  return val
244
  print("dead")
245
+ """,
246
+ """\
247
+ def helper(a, b):
248
+ return a + b
249
+
250
+ def pipeline(tmp, xs, x):
251
+ out = []
252
+ for i in xs:
253
+ out.append(i * 2)
254
+ if True:
255
+ y = helper(tmp, x)
256
+ if False:
257
+ y = 0
258
+ return y
259
+ y = 123
260
+ """,
261
+ """\
262
+ def add(p, q):
263
+ return p + q
264
+
265
+ def compute(x, data, tmp):
266
+ result = []
267
+ for item in data:
268
+ result.append(item * 2)
269
+ if False:
270
+ print("never")
271
+ val = add(x, tmp)
272
+ return val
273
+ """,
274
+ """\
275
+ def add(p, q):
276
+ return p + q
277
+
278
+ def compute(x, data, tmp):
279
+ res = []
280
+ for item in data:
281
+ res.append(item * 2)
282
+ flag = not not True
283
+ if True:
284
+ return add(x, tmp)
285
+ """,
286
+ """\
287
+ def plus(p, q):
288
+ return p + q
289
+
290
+ def compute(tmp, data, x):
291
+ out = []
292
+ for item in data:
293
+ out.append(item * 2)
294
+ if False:
295
+ tmp = 999
296
+ if True:
297
+ val = plus(x, tmp)
298
+ return val
299
+ """,
300
+ ]
301
 
302
 
303
  def _grade_hard(code: str) -> float:
304
+ """Score = fraction of 7 quality checks passed."""
305
  try:
306
  tree = ast.parse(code)
307
  except SyntaxError:
308
  return 0.0
309
 
310
+ source = _safe_unparse(tree)
311
  checks = 0
312
 
313
+ # 1. No generic variable names x/tmp/i in function signature
314
  has_generic = False
315
 
316
  class _GenCheck(ast.NodeVisitor):
317
  def visit_arg(self, node: ast.arg) -> None:
318
  nonlocal has_generic
319
+ if node.arg in {"x", "tmp", "i"}:
320
  has_generic = True
321
 
322
  _GenCheck().visit(tree)
323
  if not has_generic:
324
  checks += 1
325
 
326
+ # 2. No if/while False block
327
+ if ("if False" not in source) and ("while False" not in source):
328
  checks += 1
329
 
330
  # 3. if True removed (body inlined)
 
335
  if any(isinstance(n, ast.ListComp) for n in ast.walk(tree)):
336
  checks += 1
337
 
338
+ # 5. helper calls inlined (no call sites remain)
339
  calls = [n for n in ast.walk(tree) if isinstance(n, ast.Call)]
340
  fn_names = {c.func.id for c in calls if isinstance(c.func, ast.Name)}
341
+ if not ({"add", "plus", "helper"} & fn_names):
342
+ checks += 1
343
+
344
+ # 6. no unreachable after return/raise
345
+ if not _tree_has_unreachable(tree):
346
+ checks += 1
347
+
348
+ # 7. no double negation ("not not") remains
349
+ if "not not" not in source:
350
  checks += 1
351
 
352
+ return checks / 7
353
 
354
 
355
  # ---------------------------------------------------------------------------
 
367
  name="Rename Variables (Easy)",
368
  description="Rename generic variable names (x, tmp) to descriptive ones",
369
  difficulty="easy",
370
+ samples=_EASY_SAMPLES,
371
  _grade_fn=_grade_easy,
372
  )
373
  self._tasks["remove_dead_code"] = Task(
 
375
  name="Remove Dead Code (Medium)",
376
  description="Remove unreachable code, if False blocks, and unused variables",
377
  difficulty="medium",
378
+ samples=_MEDIUM_SAMPLES,
379
  _grade_fn=_grade_medium,
380
  )
381
  self._tasks["full_refactor"] = Task(
 
383
  name="Full Refactor (Hard)",
384
  description="Apply all transformations: rename, dead code, loops, conditions, inlining",
385
  difficulty="hard",
386
+ samples=_HARD_SAMPLES,
387
  _grade_fn=_grade_hard,
388
  )
389
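The unreachable-code check that this commit adds to `_grade_medium` and `_grade_hard` can be illustrated with a standalone sketch. It follows the same idea as `_has_unreachable_after_terminator` (flag any statement that follows a `return` or `raise` in a function body) but is simplified: it does not skip string-literal statements, and `has_unreachable` is a hypothetical name that does not appear in the commit.

```python
import ast


def has_unreachable(code: str) -> bool:
    """Return True if any function body has a statement after return/raise."""
    tree = ast.parse(code)
    for node in ast.walk(tree):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            terminated = False
            for stmt in node.body:
                if terminated:
                    # A statement follows a return/raise at the same level.
                    return True
                if isinstance(stmt, (ast.Return, ast.Raise)):
                    terminated = True
    return False


# Illustrative inputs, modeled on the "calc" sample in the diff.
dirty = "def calc(n):\n    total = n\n    return total\n    print('dead')\n"
clean = "def calc(n):\n    return n\n"
print(has_unreachable(dirty), has_unreachable(clean))
```

Like the helper in the diff, this sketch only scans the top-level statements of each function body, so a statement after a `return` inside a nested `if` or loop would not be flagged.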
 
inference.py CHANGED
@@ -1,17 +1,18 @@
1
  """
2
  ACRE inference script for OpenEnv submission evaluation.
3
 
4
- Required environment variables:
5
- API_BASE_URL: LLM API endpoint (default allowed)
6
- MODEL_NAME: model identifier (default allowed)
7
- HF_TOKEN: API token for the OpenAI-compatible endpoint
8
- ENV_URL: running ACRE server base URL
9
-
10
- Optional:
11
- LOCAL_IMAGE_NAME: present for evaluator compatibility when using a local
12
- Docker image launcher.
13
-
14
- Stdout format uses strict START / STEP / END event markers.
15
  """
16
  from __future__ import annotations
17
 
@@ -20,7 +21,7 @@ import os
20
  import re
21
  import sys
22
  import time
23
- from typing import Dict, List, Tuple
24
 
25
  import requests
26
  from openai import OpenAI
@@ -95,7 +96,7 @@ def grade(task_id: str, code: str) -> float:
95
  return float(response.json().get("score", 0.0))
96
 
97
 
98
- def choose_action(client: OpenAI, state: dict, task_id: str) -> Tuple[int, str]:
99
  def heuristic_action() -> Tuple[int, str]:
100
  code = str(state.get("current_code", ""))
101
  step_i = int(state.get("episode_steps", 0))
@@ -141,7 +142,8 @@ def choose_action(client: OpenAI, state: dict, task_id: str) -> Tuple[int, str]:
141
  return 1, "heuristic: remove remaining dead code"
142
  return 3, "heuristic: condition optimization as safe default"
143
 
144
- if not HF_TOKEN:
145
  return heuristic_action()
146
 
147
  messages = [
@@ -184,23 +186,12 @@ def choose_action(client: OpenAI, state: dict, task_id: str) -> Tuple[int, str]:
184
  return heuristic_action()
185
 
186
 
187
- def run_episode(client: OpenAI, task_id: str, episode_num: int) -> float:
188
  reset_env(task_id)
189
  state = get_state()
190
 
191
- print(
192
- json.dumps(
193
- {
194
- "event": "START",
195
- "episode": episode_num,
196
- "task_id": task_id,
197
- "initial_complexity": state.get("complexity", 0),
198
- "initial_code_length": len(state.get("current_code", "")),
199
- "timestamp": time.time(),
200
- }
201
- ),
202
- flush=True,
203
- )
204
 
205
  cumulative_reward = 0.0
206
 
@@ -214,25 +205,8 @@ def run_episode(client: OpenAI, task_id: str, episode_num: int) -> float:
214
  norm_reward = float(reward_payload.get("normalized", (raw_reward + 32) / 52))
215
  cumulative_reward += raw_reward
216
 
217
- print(
218
- json.dumps(
219
- {
220
- "event": "STEP",
221
- "episode": episode_num,
222
- "step": step_num,
223
- "action": action,
224
- "action_name": ACTION_MEANINGS.get(action, "unknown"),
225
- "reason": reason,
226
- "reward": round(raw_reward, 4),
227
- "normalized_reward": round(norm_reward, 4),
228
- "cumulative_reward": round(cumulative_reward, 4),
229
- "changed": result.get("info", {}).get("changed", False),
230
- "reward_components": reward_payload.get("components", {}),
231
- "done": result.get("done", False),
232
- }
233
- ),
234
- flush=True,
235
- )
236
 
237
  if result.get("done") or result.get("terminated") or result.get("truncated"):
238
  break
@@ -240,21 +214,8 @@ def run_episode(client: OpenAI, task_id: str, episode_num: int) -> float:
240
  final_state = get_state()
241
  task_score = grade(task_id, final_state.get("current_code", ""))
242
 
243
- print(
244
- json.dumps(
245
- {
246
- "event": "END",
247
- "episode": episode_num,
248
- "task_id": task_id,
249
- "cumulative_reward": round(cumulative_reward, 4),
250
- "normalized_cumulative": round((cumulative_reward + 32) / 52, 4),
251
- "task_score": round(task_score, 4),
252
- "final_complexity": final_state.get("complexity", 0),
253
- "timestamp": time.time(),
254
- }
255
- ),
256
- flush=True,
257
- )
258
 
259
  return task_score
260
 
@@ -263,7 +224,9 @@ def main() -> None:
263
  if not ENV_URL:
264
  raise SystemExit("ENV_URL is required. Example: ENV_URL=http://localhost:7860")
265
 
266
- client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN or "dummy")
267
 
268
  scores: List[float] = []
269
  for i, task_id in enumerate(TASKS, start=1):
 
1
  """
2
  ACRE inference script for OpenEnv submission evaluation.
3
 
4
+ Environment variables:
5
+ - API_BASE_URL: LLM API endpoint (default allowed)
6
+ - MODEL_NAME: model identifier (default allowed)
7
+ - HF_TOKEN: API token for the OpenAI-compatible endpoint (NO default)
8
+ - ENV_URL: running ACRE server base URL (required)
9
+ - LOCAL_IMAGE_NAME: present for evaluator compatibility (optional)
10
+ - USE_LLM: set to "1" to enable LLM action selection when HF_TOKEN is set
11
+
12
+ STRICT stdout format (do not change):
13
+ START <task_id>
14
+ STEP <action_int>
15
+ END <score_float>
16
  """
17
  from __future__ import annotations
18
 
 
21
  import re
22
  import sys
23
  import time
24
+ from typing import Dict, List, Optional, Tuple
25
 
26
  import requests
27
  from openai import OpenAI
 
96
  return float(response.json().get("score", 0.0))
97
 
98
 
99
+ def choose_action(client: Optional[OpenAI], state: dict, task_id: str) -> Tuple[int, str]:
100
  def heuristic_action() -> Tuple[int, str]:
101
  code = str(state.get("current_code", ""))
102
  step_i = int(state.get("episode_steps", 0))
 
142
  return 1, "heuristic: remove remaining dead code"
143
  return 3, "heuristic: condition optimization as safe default"
144
 
145
+ use_llm = bool(HF_TOKEN) and os.getenv("USE_LLM", "0") == "1"
146
+ if (not use_llm) or client is None:
147
  return heuristic_action()
148
 
149
  messages = [
 
186
  return heuristic_action()
187
 
188
 
189
+ def run_episode(client: Optional[OpenAI], task_id: str, episode_num: int) -> float:
190
  reset_env(task_id)
191
  state = get_state()
192
 
193
+ # STRICT logging format required by evaluator.
194
+ print(f"START {task_id}", flush=True)
195
 
196
  cumulative_reward = 0.0
197
 
 
205
  norm_reward = float(reward_payload.get("normalized", (raw_reward + 32) / 52))
206
  cumulative_reward += raw_reward
207
 
208
+ # STRICT logging format required by evaluator.
209
+ print(f"STEP {int(action)}", flush=True)
210
 
211
  if result.get("done") or result.get("terminated") or result.get("truncated"):
212
  break
 
214
  final_state = get_state()
215
  task_score = grade(task_id, final_state.get("current_code", ""))
216
 
217
+ # STRICT logging format required by evaluator.
218
+ print(f"END {task_score:.4f}", flush=True)
219
 
220
  return task_score
221
 
 
224
  if not ENV_URL:
225
  raise SystemExit("ENV_URL is required. Example: ENV_URL=http://localhost:7860")
226
 
227
+ client: Optional[OpenAI] = None
228
+ if HF_TOKEN and os.getenv("USE_LLM", "0") == "1":
229
+ client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
230
 
231
  scores: List[float] = []
232
  for i, task_id in enumerate(TASKS, start=1):
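To show how a consumer might read the line-based log format described in the new `inference.py` docstring (`START <task_id>`, `STEP <action_int>`, `END <score_float>`), here is a minimal parser sketch. The evaluator's real parsing logic is not shown in this commit, so `parse_log` and the layout of the returned records are assumptions made for illustration only.

```python
from typing import Any, Dict, List


def parse_log(lines: List[str]) -> List[Dict[str, Any]]:
    """Group START/STEP/END lines into one record per episode."""
    episodes: List[Dict[str, Any]] = []
    current: Dict[str, Any] = {}
    for line in lines:
        parts = line.strip().split(maxsplit=1)
        if len(parts) != 2:
            continue  # skip blank or malformed lines
        tag, value = parts
        if tag == "START":
            current = {"task_id": value, "actions": [], "score": None}
        elif tag == "STEP" and current:
            current["actions"].append(int(value))
        elif tag == "END" and current:
            current["score"] = float(value)
            episodes.append(current)
            current = {}
    return episodes


# Hypothetical stdout captured from one episode.
sample = ["START rename_variables", "STEP 0", "STEP 3", "END 1.0000"]
print(parse_log(sample))
```

A format this small is easy to parse with a single `split`, which is presumably one reason to prefer it over the earlier JSON event lines, although the commit does not state the evaluator's motivation.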
openenv_interface.py CHANGED
@@ -62,7 +62,23 @@ class OpenEnvRefactorEnv(OpenEnvBase):
62
  task = self._registry.get_task(task_id)
63
  if task is None:
64
  raise ValueError(f"Task '{task_id}' not found")
65
- initial_code = task.initial_code
66
 
67
  if initial_code is None:
68
  return None
 
62
  task = self._registry.get_task(task_id)
63
  if task is None:
64
  raise ValueError(f"Task '{task_id}' not found")
65
+ # Load a multi-sample dataset for this task. Sample selection is
66
+ # deterministic given the `seed` passed to `reset()`.
67
+ samples = list(getattr(task, "samples", []) or [])
68
+ if not samples:
69
+ initial_code = task.initial_code
70
+ else:
71
+ self._env.dataset = CodeSampleDataset(
72
+ [
73
+ CodeSample(
74
+ id=f"{task_id}:{i}",
75
+ language="python",
76
+ code=str(src),
77
+ )
78
+ for i, src in enumerate(samples)
79
+ ]
80
+ )
81
+ return None
82
 
83
  if initial_code is None:
84
  return None
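The comment added in `openenv_interface.py` says that sample selection is deterministic given the `seed` passed to `reset()`. The selection code itself is not part of this diff, so the sketch below only shows one common way to obtain that property, using a private `random.Random(seed)` instance. `pick_sample` is hypothetical and does not use `CodeSampleDataset`.

```python
import random
from typing import List, Optional


def pick_sample(samples: List[str], seed: Optional[int]) -> str:
    """Choose one sample; the same seed always yields the same choice."""
    rng = random.Random(seed)  # private RNG, leaves the global random state untouched
    return rng.choice(samples)


# Illustrative stand-ins for a task's sample list.
samples = ["def a(): pass", "def b(): pass", "def c(): pass"]
first = pick_sample(samples, seed=42)
second = pick_sample(samples, seed=42)
print(first == second)
```

Passing `seed=None` would fall back to system entropy, so reproducible baselines would need an explicit seed under this approach.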
validate.py CHANGED
@@ -204,8 +204,16 @@ def run_validation(base_url: str) -> int:
204
  try:
205
  inference_src = read_text("inference.py")
206
  failures += 0 if check("inference.py exists", True) else 1
207
- for marker in ['"event": "START"', '"event": "STEP"', '"event": "END"']:
208
- failures += 0 if check(f"inference.py emits {marker}", marker in inference_src) else 1
209
  failures += 0 if check(
210
  "Uses OpenAI client",
211
  "from openai import OpenAI" in inference_src,
 
204
  try:
205
  inference_src = read_text("inference.py")
206
  failures += 0 if check("inference.py exists", True) else 1
207
+ # Accept either the older JSON event markers or the strict hackathon
208
+ # line-based format:
209
+ # START <task_id>
210
+ # STEP <action>
211
+ # END <score>
212
+ json_markers_ok = all(m in inference_src for m in ['"event": "START"', '"event": "STEP"', '"event": "END"'])
213
+ line_markers_ok = all(m in inference_src for m in ["START ", "STEP ", "END "])
214
+ failures += 0 if check("inference.py emits START marker", json_markers_ok or line_markers_ok) else 1
215
+ failures += 0 if check("inference.py emits STEP marker", json_markers_ok or line_markers_ok) else 1
216
+ failures += 0 if check("inference.py emits END marker", json_markers_ok or line_markers_ok) else 1
217
  failures += 0 if check(
218
  "Uses OpenAI client",
219
  "from openai import OpenAI" in inference_src,