PRANAV05092003 committed on
Commit bc5030f · 1 Parent(s): 8422246

Fixed structure (moved files to root)

.gitignore ADDED
@@ -0,0 +1,26 @@
+ __pycache__/
+ *.pyc
+ *.pyo
+ *.pyd
+ .Python
+ *.egg-info/
+ dist/
+ build/
+ .env
+ .venv
+ venv/
+ *.zip
+ acre_agent.zip
+ *.log
+ .DS_Store
+ .deps/
+ libs/
+ numpy.libs/
+ *.dll
+ *.so
+ *.dylib
+ env/
+ ENV/
+ .cache/
+ .huggingface/
+ Thumbs.db
Dockerfile ADDED
@@ -0,0 +1,23 @@
+ FROM python:3.11-slim
+
+ WORKDIR /app
+
+ RUN apt-get update && apt-get install -y --no-install-recommends \
+     build-essential \
+     && rm -rf /var/lib/apt/lists/*
+
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ COPY . .
+
+ ENV API_BASE_URL=https://api.openai.com/v1
+ ENV MODEL_NAME=gpt-4o-mini
+ ENV PORT=7860
+
+ EXPOSE 7860
+
+ HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
+     CMD python -c "import requests; requests.get('http://localhost:7860/').raise_for_status()"
+
+ CMD ["python", "server.py"]
README.md CHANGED
@@ -1,10 +1,174 @@
  ---
- title: Autonomous Code Refactoring Env
- emoji: ⚡
  colorFrom: blue
- colorTo: red
  sdk: docker
  pinned: false
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
  ---
+ title: ACRE - Autonomous Code Refactoring Environment
  colorFrom: blue
+ colorTo: green
  sdk: docker
+ app_port: 7860
  pinned: false
+ license: mit
+ tags:
+ - openenv
  ---

+ # ACRE - Autonomous Code Refactoring Environment
+
+ ACRE is an OpenEnv-compatible environment for autonomous Python code refactoring. An agent receives real code-cleanup tasks and must improve the code through AST-based transformations while receiving dense reward feedback for correctness, simplification, and performance.
+
+ ## Environment Overview and Motivation
+
+ This project simulates a realistic developer workflow: cleaning up messy Python code, removing dead logic, simplifying loops, and inlining trivial helpers. The canonical OpenEnv wrapper lives in `openenv_interface.py`, while the original Gymnasium-compatible environment remains available for RL training and demos.
+
+ ## Definitions of Action and Observation Spaces
+
+ ### Action Space - Discrete(5)
+
+ | Action | Name | Description |
+ |---|---|---|
+ | 0 | rename_variable | Rename generic variables like `x`, `tmp`, and `i` |
+ | 1 | remove_dead_code | Remove unreachable statements, `if False` branches, and unused assignments |
+ | 2 | simplify_loop | Convert append-loops into list comprehensions |
+ | 3 | optimize_condition | Simplify `not not x`, `if True`, `if False`, and boolean comparisons |
+ | 4 | inline_function | Inline simple single-return module-level functions |
+
+ ### Observation Space - Box(4,)
+
+ The environment tracks:
+
+ - `code_length`
+ - `complexity_score`
+ - `runtime_s`
+ - `error_flag`
+
+ ### Typed OpenEnv Models
+
+ The submission-facing interface uses Pydantic models in `models.py`:
+
+ - `ObservationModel`
+ - `ActionModel`
+ - `RewardModel`
+ - `StateResponse`
+
+ The canonical interface is:
+
+ ```python
+ observation = env.reset(...)
+ observation, reward, done, info = env.step(action)
+ state = env.state()
+ ```
+
+ ## Task Descriptions with Expected Difficulty Levels
+
+ | Task ID | Difficulty | Objective |
+ |---|---|---|
+ | `rename_variables` | Easy | Remove generic variable names from the snippet |
+ | `remove_dead_code` | Medium | Eliminate dead branches, unreachable code, and unused assignments |
+ | `full_refactor` | Hard | Combine renaming, dead-code removal, loop simplification, condition optimization, and inlining |
+
+ Each task includes a deterministic AST-based grader returning a score in `[0.0, 1.0]`.
+
+ ## Reward Design
+
+ Rewards are shaped throughout the trajectory instead of only at the end.
+
+ - Success reward for syntactically valid, executable output
+ - Complexity reward when control-flow complexity decreases
+ - Performance reward when runtime improves
+ - Error penalty for invalid or failing code
+ - No-change penalty to discourage loops and unproductive actions
+
+ Raw reward range is `[-32, 20]`, normalized to `[0.0, 1.0]` with `(raw + 32) / 52`.
+
+ ## HTTP API
+
+ | Method | Path | Purpose |
+ |---|---|---|
+ | GET | `/` | Health check |
+ | GET | `/health` | Compatibility health check |
+ | POST | `/reset` | Reset environment and return typed observation/state |
+ | POST | `/step` | Apply one action and return typed observation/reward/done |
+ | GET | `/state` | Return the current typed state |
+ | GET | `/tasks` | List available tasks |
+ | POST | `/tasks/{task_id}/grade` | Grade submitted code |
+
+ ## Setup and Usage Instructions
+
+ ### Local setup
+
+ ```bash
+ pip install -r requirements.txt
+ python server.py
+ ```
+
+ ### Baseline inference
+
+ Set environment variables before running:
+
+ ```bash
+ export API_BASE_URL=https://api.openai.com/v1
+ export MODEL_NAME=gpt-4o-mini
+ export HF_TOKEN=your_key
+ export ENV_URL=http://localhost:7860
+ python inference.py
+ ```
+
+ Notes:
+
+ - `API_BASE_URL` and `MODEL_NAME` have defaults in `inference.py`
+ - `HF_TOKEN` is optional because the script falls back to a deterministic heuristic baseline
+ - `LOCAL_IMAGE_NAME` is read for evaluator compatibility when using a local Docker image launcher
+
+ ### Docker / Hugging Face Spaces
+
+ ```bash
+ docker build -t acre .
+ docker run -p 7860:7860 \
+   -e API_BASE_URL=https://api.openai.com/v1 \
+   -e MODEL_NAME=gpt-4o-mini \
+   -e HF_TOKEN=your_key \
+   -e ENV_URL=http://localhost:7860 \
+   acre
+ ```
+
+ The repository is configured for a Docker-based Hugging Face Space and includes the `openenv` tag in the front matter.
+
+ ## Validation
+
+ Run the repository validator:
+
+ ```bash
+ python validate.py --url http://localhost:7860
+ ```
+
+ When using the official hackathon tooling, also run:
+
+ ```bash
+ openenv validate
+ ```
+
+ ## Interactive Demo
+
+ Start the server and open:
+
+ ```text
+ http://localhost:7860/demo
+ ```
+
+ The demo shows:
+
+ - Original code
+ - Optimized code
+ - Unified diff
+ - Per-step action and reward logs
+
+ ## Baseline Performance Scores
+
+ The deterministic fallback policy used by `inference.py` produces the following reproducible task scores:
+
+ | Task | Score |
+ |---|---|
+ | `rename_variables` | 1.0 |
+ | `remove_dead_code` | 1.0 |
+ | `full_refactor` | 1.0 |
+ | Average | 1.0 |
+
+ These scores come from the built-in heuristic policy with `HF_TOKEN` unset, which keeps the baseline reproducible across runs.
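The reward normalization stated in the README's "Reward Design" section, `(raw + 32) / 52` over the raw range `[-32, 20]`, can be sketched as below. This is a minimal illustration, not code from the commit: the names `RAW_MIN`, `RAW_MAX`, and `normalize_reward` are hypothetical, and clamping out-of-range values is an added assumption that the README does not confirm.

```python
RAW_MIN = -32.0  # lower bound of the raw reward range given in the README
RAW_MAX = 20.0   # upper bound of the raw reward range given in the README


def normalize_reward(raw: float) -> float:
    """Map a raw reward from [RAW_MIN, RAW_MAX] onto [0.0, 1.0].

    Values outside the documented range are clamped first (an assumption
    made for this sketch so the result always stays inside [0.0, 1.0]).
    """
    clamped = min(max(raw, RAW_MIN), RAW_MAX)
    # Equivalent to (raw + 32) / 52 for in-range values.
    return (clamped - RAW_MIN) / (RAW_MAX - RAW_MIN)


if __name__ == "__main__":
    for raw in (-32.0, -6.0, 20.0):
        print(raw, "->", normalize_reward(raw))
```

Under this mapping the midpoint of the raw range, `-6`, lands at `0.5`, so a normalized reward below one half corresponds to a raw reward below `-6`.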
acre/__init__.py ADDED
@@ -0,0 +1,14 @@
+ """
+ ACRE (Autonomous Code Refactoring Environment).
+
+ Package skeleton for an RL-based code refactoring system.
+ """
+
+ __all__ = [
+     "env",
+     "actions",
+     "datasets",
+     "training",
+     "utils",
+ ]
+
acre/actions/__init__.py ADDED
@@ -0,0 +1,6 @@
+ """Action definitions and transformations for ACRE."""
+
+ from .transformations import Transformation, TransformationResult
+
+ __all__ = ["Transformation", "TransformationResult"]
+
acre/actions/transformations.py ADDED
@@ -0,0 +1,518 @@
+ from __future__ import annotations
+
+ import ast
+ import copy
+ from dataclasses import dataclass
+ from itertools import zip_longest
+ from typing import Any, Dict, Protocol
+
+
+ @dataclass(frozen=True)
+ class TransformationResult:
+     """Output of applying a transformation (placeholder)."""
+
+     code: str
+     changed: bool
+     metadata: Dict[str, Any]
+
+
+ class Transformation(Protocol):
+     """Protocol for a code transformation."""
+
+     name: str
+
+     def apply(self, code: str) -> TransformationResult: ...
+
+
+ def noop_transformation(code: str) -> TransformationResult:
+     """Baseline transformation that leaves code unchanged."""
+     return TransformationResult(code=code, changed=False, metadata={"kind": "noop"})
+
+
+ def _finalize_result(*, original: str, out: str, meta: Dict[str, Any]) -> TransformationResult:
+     """
+     Standardize metadata across transformations.
+
+     - Adds `lines_changed` and `impact` for explainability/metrics.
+     - Ensures formatting-only changes don't count as `changed`.
+     """
+
+     def _count_lines_changed(a: str, b: str) -> int:
+         a_lines = a.splitlines()
+         b_lines = b.splitlines()
+         changed = 0
+         for x, y in zip_longest(a_lines, b_lines, fillvalue=None):
+             if x != y:
+                 changed += 1
+         return int(changed)
+
+     lines_changed = _count_lines_changed(original, out)
+
+     # Fallback identity check: AST round-trips can reformat without changing meaning.
+     # If the textual content is the same after stripping, treat it as unchanged.
+     if out.strip() == original.strip():
+         meta["success"] = False
+         meta["lines_changed"] = 0
+         meta["impact"] = "low"
+         return TransformationResult(code=original, changed=False, metadata=meta)
+
+     meta["lines_changed"] = lines_changed
+     meta["impact"] = "high" if lines_changed >= 3 else "low"
+     meta["success"] = True
+     return TransformationResult(code=out, changed=True, metadata=meta)
+
+
+ def _unchanged(*, code: str, meta: Dict[str, Any]) -> TransformationResult:
+     meta.setdefault("success", False)
+     meta.setdefault("lines_changed", 0)
+     meta.setdefault("impact", "low")
+     return TransformationResult(code=code, changed=False, metadata=meta)
+
+
+ def rename_variable(code: str) -> TransformationResult:
+     """
+     Rename simple, generic variable names to more descriptive ones.
+
+     Hackathon-scope heuristic:
+     - Rename generic names in priority order: x, tmp, i.
+     - Uses descriptive base names and avoids collisions.
+     - Applies to Name nodes and function args.
+     """
+     meta: Dict[str, Any] = {"type": "rename_variable", "success": False}
+     try:
+         tree = ast.parse(code)
+
+         class _NameCollector(ast.NodeVisitor):
+             def __init__(self) -> None:
+                 self.names: set[str] = set()
+
+             def visit_Name(self, node: ast.Name) -> None:  # noqa: N802
+                 self.names.add(node.id)
+
+             def visit_arg(self, node: ast.arg) -> None:  # noqa: N802
+                 self.names.add(node.arg)
+
+         collector = _NameCollector()
+         collector.visit(tree)
+
+         rename_plan = [
+             ("x", "value"),
+             ("tmp", "temp_value"),
+             ("i", "index"),
+         ]
+
+         old = ""
+         base_new = "value"
+         for candidate_old, candidate_base in rename_plan:
+             if candidate_old in collector.names:
+                 old = candidate_old
+                 base_new = candidate_base
+                 break
+
+         if not old:
+             return _unchanged(code=code, meta=meta)
+
+         new = base_new
+         i = 1
+         while new in collector.names:
+             new = f"{base_new}{i}"
+             i += 1
+
+         class _Renamer(ast.NodeTransformer):
+             def __init__(self, old_name: str, new_name: str) -> None:
+                 self.old_name = old_name
+                 self.new_name = new_name
+                 self.changed = False
+
+             def visit_Name(self, node: ast.Name) -> ast.AST:  # noqa: N802
+                 if node.id == self.old_name:
+                     self.changed = True
+                     return ast.copy_location(ast.Name(id=self.new_name, ctx=node.ctx), node)
+                 return node
+
+             def visit_arg(self, node: ast.arg) -> ast.AST:  # noqa: N802
+                 if node.arg == self.old_name:
+                     self.changed = True
+                     new_node = copy.copy(node)
+                     new_node.arg = self.new_name
+                     return new_node
+                 return node
+
+         renamer = _Renamer(old, new)
+         tree = renamer.visit(tree)
+         ast.fix_missing_locations(tree)
+
+         if not renamer.changed:
+             return _unchanged(code=code, meta=meta)
+
+         out = ast.unparse(tree)
+         meta["old"] = old
+         meta["new"] = new
+         # Renames tend to be small diffs; label as low impact unless the diff is large.
+         return _finalize_result(original=code, out=out, meta=meta)
+     except Exception:
+         return _unchanged(code=code, meta=meta)
+
+
+ def remove_dead_code(code: str) -> TransformationResult:
+     """
+     Remove simple dead code patterns.
+
+     Hackathon-scope heuristics:
+     - Drop statements after `return` / `raise` in the same block.
+     - Remove `if False: ...` blocks (keep `else` if present).
+     - Remove assignments to unused names in a block (very simple check).
+     """
+     meta: Dict[str, Any] = {"type": "remove_dead_code", "success": False}
+
+     try:
+         tree = ast.parse(code)
+
+         def _is_const_bool(expr: ast.AST, value: bool) -> bool:
+             return isinstance(expr, ast.Constant) and isinstance(expr.value, bool) and expr.value is value
+
+         class _LoadNameCollector(ast.NodeVisitor):
+             def __init__(self) -> None:
+                 self.loaded: set[str] = set()
+
+             def visit_Name(self, node: ast.Name) -> None:  # noqa: N802
+                 if isinstance(node.ctx, ast.Load):
+                     self.loaded.add(node.id)
+
+         class _DeadCode(ast.NodeTransformer):
+             def __init__(self) -> None:
+                 self.changed = False
+
+             def _prune_unreachable(self, stmts: list[ast.stmt]) -> list[ast.stmt]:
+                 out: list[ast.stmt] = []
+                 unreachable = False
+                 for s in stmts:
+                     if unreachable:
+                         self.changed = True
+                         continue
+                     out.append(s)
+                     if isinstance(s, (ast.Return, ast.Raise)):
+                         unreachable = True
+                 return out
+
+             def _remove_unused_assigns(self, stmts: list[ast.stmt]) -> list[ast.stmt]:
+                 collector = _LoadNameCollector()
+                 for s in stmts:
+                     collector.visit(s)
+                 used = collector.loaded
+
+                 out: list[ast.stmt] = []
+                 for s in stmts:
+                     if isinstance(s, ast.Assign) and all(isinstance(t, ast.Name) for t in s.targets):
+                         targets = [t.id for t in s.targets if isinstance(t, ast.Name)]
+                         # Remove only if *all* assigned names are unused.
+                         if targets and all(t not in used for t in targets):
+                             self.changed = True
+                             continue
+                     if isinstance(s, ast.AnnAssign) and isinstance(s.target, ast.Name):
+                         if s.target.id not in used:
+                             self.changed = True
+                             continue
+                     out.append(s)
+                 return out
+
+             def _clean_block(self, stmts: list[ast.stmt]) -> list[ast.stmt]:
+                 # First apply transformations inside statements.
+                 visited = [self.visit(s) for s in stmts]
+                 flat: list[ast.stmt] = []
+                 for s in visited:
+                     if s is None:
+                         self.changed = True
+                         continue
+                     if isinstance(s, list):
+                         flat.extend([x for x in s if isinstance(x, ast.stmt)])
+                         self.changed = True
+                     else:
+                         flat.append(s)
+
+                 flat = self._prune_unreachable(flat)
+                 flat = self._remove_unused_assigns(flat)
+                 return flat
+
+             def visit_Module(self, node: ast.Module) -> ast.AST:  # noqa: N802
+                 node.body = self._clean_block(node.body)
+                 return node
+
+             def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST:  # noqa: N802
+                 node.body = self._clean_block(node.body)
+                 return node
+
+             def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AST:  # noqa: N802
+                 node.body = self._clean_block(node.body)
+                 return node
+
+             def visit_If(self, node: ast.If) -> ast.AST | list[ast.stmt]:  # noqa: N802
+                 node = self.generic_visit(node)
+                 if _is_const_bool(node.test, False):
+                     self.changed = True
+                     return node.orelse or []
+                 return node
+
+             def visit_While(self, node: ast.While) -> ast.AST | None:  # noqa: N802
+                 node = self.generic_visit(node)
+                 if _is_const_bool(node.test, False):
+                     self.changed = True
+                     return None
+                 return node
+
+         dc = _DeadCode()
+         tree = dc.visit(tree)
+         ast.fix_missing_locations(tree)
+         if not dc.changed:
+             return _unchanged(code=code, meta=meta)
+
+         out = ast.unparse(tree)
+         return _finalize_result(original=code, out=out, meta=meta)
+     except Exception:
+         return _unchanged(code=code, meta=meta)
+
+
+ def simplify_loops(code: str) -> TransformationResult:
+     """
+     Simplify very basic loop patterns into more pythonic forms.
+
+     Supported pattern (only when adjacent in the same block):
+     - xs = []
+       for t in it:
+           xs.append(expr)
+       => xs = [expr for t in it]
+     """
+     meta: Dict[str, Any] = {"type": "simplify_loops", "success": False}
+     try:
+         tree = ast.parse(code)
+
+         class _LoopSimplifier(ast.NodeTransformer):
+             def __init__(self) -> None:
+                 self.changed = False
+
+             def _simplify_body(self, body: list[ast.stmt]) -> list[ast.stmt]:
+                 out: list[ast.stmt] = []
+                 i = 0
+                 while i < len(body):
+                     cur = body[i]
+                     nxt = body[i + 1] if i + 1 < len(body) else None
+
+                     if (
+                         isinstance(cur, ast.Assign)
+                         and len(cur.targets) == 1
+                         and isinstance(cur.targets[0], ast.Name)
+                         and isinstance(cur.value, ast.List)
+                         and cur.value.elts == []
+                         and isinstance(nxt, ast.For)
+                         and len(nxt.body) == 1
+                         and isinstance(nxt.body[0], ast.Expr)
+                         and isinstance(nxt.body[0].value, ast.Call)
+                     ):
+                         list_name = cur.targets[0].id
+                         call = nxt.body[0].value
+                         if (
+                             isinstance(call.func, ast.Attribute)
+                             and isinstance(call.func.value, ast.Name)
+                             and call.func.value.id == list_name
+                             and call.func.attr == "append"
+                             and len(call.args) == 1
+                             and not call.keywords
+                         ):
+                             # Build list comprehension: [call.args[0] for <target> in <iter>]
+                             comp = ast.ListComp(
+                                 elt=call.args[0],
+                                 generators=[
+                                     ast.comprehension(
+                                         target=nxt.target,
+                                         iter=nxt.iter,
+                                         ifs=[],
+                                         is_async=0,
+                                     )
+                                 ],
+                             )
+                             new_assign = ast.Assign(targets=[ast.Name(id=list_name, ctx=ast.Store())], value=comp)
+                             out.append(ast.copy_location(new_assign, cur))
+                             self.changed = True
+                             i += 2
+                             continue
+
+                     out.append(cur)
+                     i += 1
+
+                 return out
+
+             def visit_Module(self, node: ast.Module) -> ast.AST:  # noqa: N802
+                 node = self.generic_visit(node)
+                 node.body = self._simplify_body(node.body)
+                 return node
+
+             def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST:  # noqa: N802
+                 node = self.generic_visit(node)
+                 node.body = self._simplify_body(node.body)
+                 return node
+
+             def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AST:  # noqa: N802
+                 node = self.generic_visit(node)
+                 node.body = self._simplify_body(node.body)
+                 return node
+
+         simp = _LoopSimplifier()
+         tree = simp.visit(tree)
+         ast.fix_missing_locations(tree)
+         if not simp.changed:
+             return _unchanged(code=code, meta=meta)
+
+         out = ast.unparse(tree)
+         return _finalize_result(original=code, out=out, meta=meta)
+     except Exception:
+         return _unchanged(code=code, meta=meta)
+
+
+ def simplify_loop(code: str) -> TransformationResult:
+     # Backwards-compatible alias for the environment's action mapping.
+     return simplify_loops(code)
+
+
+ def optimize_condition(code: str) -> TransformationResult:
+     """
+     Simplify redundant boolean conditions.
+
+     Hackathon-scope heuristics:
+     - Replace `if True:` with its body; `if False:` with `else` (if present).
+     - Simplify `not not X` -> `X`.
+     - Simplify comparisons to True/False: `X == True` -> `X`, `X == False` -> `not X`.
+     """
+     meta: Dict[str, Any] = {"type": "optimize_condition", "success": False}
+     try:
+         tree = ast.parse(code)
+
+         def _is_bool_const(node: ast.AST, value: bool) -> bool:
+             return isinstance(node, ast.Constant) and isinstance(node.value, bool) and node.value is value
+
+         class _CondOpt(ast.NodeTransformer):
+             def __init__(self) -> None:
+                 self.changed = False
+
+             def visit_UnaryOp(self, node: ast.UnaryOp) -> ast.AST:  # noqa: N802
+                 node = self.generic_visit(node)
+                 if isinstance(node.op, ast.Not) and isinstance(node.operand, ast.UnaryOp) and isinstance(node.operand.op, ast.Not):
+                     self.changed = True
+                     return node.operand.operand
+                 return node
+
+             def visit_Compare(self, node: ast.Compare) -> ast.AST:  # noqa: N802
+                 node = self.generic_visit(node)
+                 if len(node.ops) == 1 and len(node.comparators) == 1:
+                     op = node.ops[0]
+                     rhs = node.comparators[0]
+                     if isinstance(op, (ast.Eq, ast.Is)) and _is_bool_const(rhs, True):
+                         self.changed = True
+                         return node.left
+                     if isinstance(op, (ast.Eq, ast.Is)) and _is_bool_const(rhs, False):
+                         self.changed = True
+                         return ast.UnaryOp(op=ast.Not(), operand=node.left)
+                 return node
+
+             def visit_If(self, node: ast.If) -> ast.AST | list[ast.stmt]:  # noqa: N802
+                 node = self.generic_visit(node)
+                 if _is_bool_const(node.test, True):
+                     self.changed = True
+                     return node.body
+                 if _is_bool_const(node.test, False):
+                     self.changed = True
+                     return node.orelse or []
+                 return node
+
+         opt = _CondOpt()
+         tree = opt.visit(tree)
+         ast.fix_missing_locations(tree)
+         if not opt.changed:
+             return _unchanged(code=code, meta=meta)
+
+         out = ast.unparse(tree)
+         return _finalize_result(original=code, out=out, meta=meta)
+     except Exception:
+         return _unchanged(code=code, meta=meta)
+
+
+ def inline_function(code: str) -> TransformationResult:
+     """
+     Inline very simple functions into their call sites.
+
+     Supported pattern:
+     - def f(a, b): return <expr using only a,b>
+     - Replace calls: f(x, y) -> <expr with a->x, b->y>
+     Only handles module-level functions and positional args.
+     """
+     meta: Dict[str, Any] = {"type": "inline_function", "success": False}
+     try:
+         tree = ast.parse(code)
+
+         simple_fns: Dict[str, tuple[list[str], ast.AST]] = {}
+         for node in tree.body:
+             if not isinstance(node, ast.FunctionDef):
+                 continue
+             if node.decorator_list:
+                 continue
+             args = node.args
+             if args.vararg or args.kwarg or args.kwonlyargs or args.defaults or args.posonlyargs:
+                 continue
+             if len(node.body) != 1 or not isinstance(node.body[0], ast.Return) or node.body[0].value is None:
+                 continue
+             arg_names = [a.arg for a in args.args]
+             # Ensure the return expression only references the function's args.
+             referenced: set[str] = set()
+
+             class _Ref(ast.NodeVisitor):
+                 def visit_Name(self, n: ast.Name) -> None:  # noqa: N802
+                     if isinstance(n.ctx, ast.Load):
+                         referenced.add(n.id)
+
+             _Ref().visit(node.body[0].value)
+             if not referenced.issubset(set(arg_names)):
+                 continue
+             simple_fns[node.name] = (arg_names, node.body[0].value)
+
+         if not simple_fns:
+             return _unchanged(code=code, meta=meta)
+
+         class _Substitute(ast.NodeTransformer):
+             def __init__(self, mapping: Dict[str, ast.AST]) -> None:
+                 self.mapping = mapping
+
+             def visit_Name(self, n: ast.Name) -> ast.AST:  # noqa: N802
+                 if isinstance(n.ctx, ast.Load) and n.id in self.mapping:
+                     return copy.deepcopy(self.mapping[n.id])
+                 return n
+
+         class _Inliner(ast.NodeTransformer):
+             def __init__(self) -> None:
+                 self.changed = False
+
+             def visit_Call(self, node: ast.Call) -> ast.AST:  # noqa: N802
+                 node = self.generic_visit(node)
+                 if not isinstance(node.func, ast.Name):
+                     return node
+                 fn = simple_fns.get(node.func.id)
+                 if fn is None:
+                     return node
+                 arg_names, expr = fn
+                 if node.keywords or len(node.args) != len(arg_names):
+                     return node
+                 mapping = {name: arg for name, arg in zip(arg_names, node.args, strict=True)}
+                 new_expr = _Substitute(mapping).visit(copy.deepcopy(expr))
+                 self.changed = True
+                 return ast.copy_location(new_expr, node)
+
+         inliner = _Inliner()
+         tree = inliner.visit(tree)
+         ast.fix_missing_locations(tree)
+         if not inliner.changed:
+             return _unchanged(code=code, meta=meta)
+
+         out = ast.unparse(tree)
+         meta["inlined"] = sorted(simple_fns.keys())
+         return _finalize_result(original=code, out=out, meta=meta)
+     except Exception:
+         return _unchanged(code=code, meta=meta)
+
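The transformations in this file share one shape: parse the source, rewrite the tree with an `ast.NodeTransformer`, then unparse. The standalone sketch below isolates that shape for the `not not X` -> `X` rule only. It is an illustration, not the commit's code: the names `DoubleNegationRemover` and `remove_double_negation` are hypothetical, and it assumes Python 3.9 or later for `ast.unparse`.

```python
import ast


class DoubleNegationRemover(ast.NodeTransformer):
    """Rewrite `not not X` to `X` (hypothetical helper for illustration)."""

    def __init__(self) -> None:
        self.changed = False

    def visit_UnaryOp(self, node: ast.UnaryOp) -> ast.AST:  # noqa: N802
        # Visit children first so nested double negations collapse bottom-up.
        self.generic_visit(node)
        inner = node.operand
        if (
            isinstance(node.op, ast.Not)
            and isinstance(inner, ast.UnaryOp)
            and isinstance(inner.op, ast.Not)
        ):
            self.changed = True
            return inner.operand
        return node


def remove_double_negation(source: str) -> tuple[str, bool]:
    """Return the rewritten source and whether any rewrite happened."""
    tree = ast.parse(source)
    remover = DoubleNegationRemover()
    tree = remover.visit(tree)
    ast.fix_missing_locations(tree)
    return ast.unparse(tree), remover.changed


if __name__ == "__main__":
    new_source, changed = remove_double_negation("if not not ready:\n    go()\n")
    print(changed)
    print(new_source)
```

One design caveat: `not not X` always yields a `bool`, while `X` may be any object, so the rewrite preserves behavior only where the value is used for its truthiness, such as an `if` test.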
acre/datasets/__init__.py ADDED
@@ -0,0 +1,6 @@
+ """Datasets and sample code providers for ACRE."""
+
+ from .code_samples import CodeSample, CodeSampleDataset
+
+ __all__ = ["CodeSample", "CodeSampleDataset"]
+
acre/datasets/code_samples.py ADDED
@@ -0,0 +1,34 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Iterable, Iterator, List, Optional
+
+
+ @dataclass(frozen=True)
+ class CodeSample:
+     """A single code sample (placeholder)."""
+
+     id: str
+     language: str
+     code: str
+
+
+ class CodeSampleDataset:
+     """
+     Minimal in-memory dataset stub.
+
+     Later versions can back this with files, Git repos, or benchmark suites.
+     """
+
+     def __init__(self, samples: Optional[Iterable[CodeSample]] = None) -> None:
+         self._samples: List[CodeSample] = list(samples or [])
+
+     def __len__(self) -> int:
+         return len(self._samples)
+
+     def __iter__(self) -> Iterator[CodeSample]:
+         return iter(self._samples)
+
+     def add(self, sample: CodeSample) -> None:
+         self._samples.append(sample)
+
acre/demo.py ADDED
@@ -0,0 +1,185 @@
+ from __future__ import annotations
+
+ import os
+ import random
+ import sys
+ from typing import Any, Optional, Tuple
+
+ from acre.datasets.code_samples import CodeSample, CodeSampleDataset
+ from acre.env.refactor_env import RefactorEnv
+
+
+ def _load_model(path: str):
+     """Load a Stable-Baselines3 PPO model if available; otherwise return None."""
+     if not os.path.exists(path):
+         return None
+     try:
+         from stable_baselines3 import PPO
+     except Exception:
+         return None
+     try:
+         return PPO.load(path)
+     except Exception:
+         return None
+
+
+ def _messy_sample_code() -> str:
+     # Intentionally "messy" but valid Python for demo purposes.
+     return (
+         "def add(a,b):\n"
+         "    x=0\n"
+         "    for i in range(a):\n"
+         "        x=x+1\n"
+         "    if True:\n"
+         "        x = x\n"
+         "    if False:\n"
+         "        y=123\n"
+         "    else:\n"
+         "        y=0\n"
+         "    def f(p,q):\n"
+         "        return p+q\n"
+         "    r = f(x,y)\n"
+         "    return r\n"
+     )
+
+
+ def _format_code_block(code: str) -> str:
+     return "\n".join(f" {line}" for line in code.rstrip().splitlines()) + "\n"
+
+
+ def _safe_print(text: str) -> None:
+     """
+     Print text safely across Windows consoles (some default encodings can't print emojis).
+     """
+     encoding = sys.stdout.encoding or "utf-8"
+     try:
+         text.encode(encoding)
+         print(text, flush=True)
+     except Exception:
+         # Fall back to ASCII-friendly markers if emojis can't be encoded.
+         safe = text.replace("✅", "[OK]").replace("⚠️", "[WARN]").replace("⚠", "[WARN]")
+         print(safe, flush=True)
+
+
+ def _compute_runtime(executor: Any, code: str) -> float:
+     """Best-effort runtime metric using the current executor contract."""
+     try:
+         res = executor.run(code, filename="demo.py")
+         if getattr(res, "exit_code", 1) == 0 and isinstance(getattr(res, "metrics", None), dict):
+             return float(res.metrics.get("runtime_s", 0.0) or 0.0)
+     except Exception:
+         pass
+     return 0.0
+
+
+ def _choose_action(model: Any, obs, env: RefactorEnv, rng: random.Random) -> Tuple[int, str]:
+     """Choose an action from the model, falling back to random."""
+     n_actions = int(getattr(getattr(env, "action_space", None), "n", 5))
+     if model is None:
+         a = int(rng.randint(0, n_actions - 1))
+         return a, "random"
+
+     try:
+         action, _state = model.predict(obs, deterministic=True)
+         # SB3 may return scalar or 1-element array.
+         if hasattr(action, "__len__"):
+             a = int(action[0])
+         else:
+             a = int(action)
+         return a, "ppo"
+     except Exception:
+         a = int(rng.randint(0, n_actions - 1))
+         return a, "random"
+
+
+ def run_demo(*, model_path: str = "acre_agent.zip", seed: int = 0) -> None:
+     rng = random.Random(seed)
+
+     # Create a dataset with one messy sample so `reset()` loads it deterministically.
+     dataset = CodeSampleDataset(
+         [
+             CodeSample(
+                 id="demo_sample",
+                 language="python",
+                 code=_messy_sample_code(),
+             )
+         ]
+     )
+     env = RefactorEnv(dataset=dataset, seed=seed)
+
+     model = _load_model(model_path)
+     model_status = "loaded" if model is not None else "not found (using random actions)"
+
+     # Reset and capture the original code/metrics.
+     obs, info = env.reset()
+     original_code = getattr(env, "_code", "")
+     original_complexity = float(getattr(env, "_compute_complexity")(original_code))
+     original_runtime = _compute_runtime(env.executor, original_code)
+
+     print("=" * 72)
+     print("ACRE: Autonomous RL Code Refactoring Agent (5-step episode)")
+     print(f"Model: {model_path} -> {model_status}")
+     print(f"Sample: {info.get('sample_id')} ({info.get('language')})")
+     print("=" * 72)
+     print("\nORIGINAL CODE:\n")
+     print(_format_code_block(original_code))
+
+     total_reward = 0.0
+     successful_transformations = 0
+     steps_taken = 0
+
+     for step_idx in range(1, 6):
+         action, policy = _choose_action(model, obs, env, rng)
+         obs, reward, terminated, truncated, step_info = env.step(action)
+         total_reward += float(reward)
+         steps_taken = step_idx
+
+         action_name = step_info.get("action_name", "unknown")
+         transform_meta = step_info.get("transform", {})
+         if isinstance(transform_meta, dict) and bool(transform_meta.get("success", False)):
+             successful_transformations += 1
+         transformed_code = getattr(env, "_code", "")
+
+         print("-" * 72)
+         print(f"STEP {step_idx}/5")
+         print(f"policy={policy} action={action} ({action_name})")
+         print(f"transform={transform_meta}")
+         print(f"reward={float(reward):.2f} components={step_info.get('reward_components')}")
+         print("\nUPDATED CODE:\n")
+         print(_format_code_block(transformed_code))
+
+         if terminated or truncated:
+             break
+
+     final_code = getattr(env, "_code", "")
+     final_complexity = float(getattr(env, "_compute_complexity")(final_code))
+     final_runtime = _compute_runtime(env.executor, final_code)
+
+     print("=" * 72)
+     print("FINAL SUMMARY")
+     print("=" * 72)
+     print(f"total_reward: {total_reward:.2f}")
+     print(f"complexity: {original_complexity:.0f} -> {final_complexity:.0f}")
+     print(f"runtime_s: {original_runtime:.4f} -> {final_runtime:.4f}")
+
+     complexity_improvement = ((original_complexity - final_complexity) / max(original_complexity, 1.0)) * 100.0
+     print(f"complexity improvement: {complexity_improvement:.2f}%")
+
+     print("\nCHANGES APPLIED:")
+     print(f"- Total steps: {steps_taken}")
+     print(f"- Successful transformations: {successful_transformations}")
+
+     if total_reward > 0:
+         _safe_print("\n✅ Code improved successfully")
+     else:
+         _safe_print("\n⚠️ No significant improvement")
+
+     print("\nFINAL CODE:\n")
+     print(_format_code_block(final_code))
+
+     env.close()
+
+
+ if __name__ == "__main__":
+     run_demo()
+
acre/main.py ADDED
@@ -0,0 +1,39 @@
+ from __future__ import annotations
+
+ import argparse
+
+ from acre.training.train_agent import TrainConfig, train
+
+
+ def _build_parser() -> argparse.ArgumentParser:
+     parser = argparse.ArgumentParser(prog="acre", description="ACRE: Autonomous Code Refactoring Environment")
+     sub = parser.add_subparsers(dest="command", required=False)
+
+     train_p = sub.add_parser("train", help="Run training (stub)")
+     train_p.add_argument("--total-steps", type=int, default=100, help="Total training steps (stub)")
+
+     sub.add_parser("demo", help="Run a small demo (stub)")
+
+     return parser
+
+
+ def run_demo() -> None:
+     # Placeholder for a future interactive/demo flow.
+     print("ACRE demo mode is not implemented yet.")
+
+
+ def main(argv: list[str] | None = None) -> None:
+     parser = _build_parser()
+     args = parser.parse_args(argv)
+
+     if args.command == "demo":
+         run_demo()
+         return
+
+     # No subcommand (or `train`): fall through to training.
+     total_steps = getattr(args, "total_steps", 100)
+     train(config=TrainConfig(total_steps=total_steps))
+
+
+ if __name__ == "__main__":
+     main()
+
acre/tasks/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from acre.tasks.task_registry import Task, TaskRegistry
+
+ __all__ = ["Task", "TaskRegistry"]
acre/tasks/task_registry.py ADDED
@@ -0,0 +1,222 @@
+ """
+ Three OpenEnv tasks with AST-based graders scoring 0.0-1.0.
+ """
+ from __future__ import annotations
+
+ import ast
+ from dataclasses import dataclass
+ from typing import Callable, Dict, List, Optional
+
+
+ @dataclass
+ class Task:
+     id: str
+     name: str
+     description: str
+     difficulty: str
+     initial_code: str
+     _grade_fn: Callable[[str], float]
+
+     def grade(self, code: str) -> float:
+         """Return a score in [0.0, 1.0]."""
+         try:
+             return float(min(1.0, max(0.0, self._grade_fn(code))))
+         except Exception:
+             return 0.0
+
+
+ # ---------------------------------------------------------------------------
+ # Task 1 - Easy: Rename generic variables
+ # ---------------------------------------------------------------------------
+ _EASY_CODE = """\
+ def compute(x, y, tmp):
+     tmp = x + y
+     x = tmp * 2
+     result = x
+     return result
+ """
+
+
+ def _grade_easy(code: str) -> float:
+     """Score = fraction of generic names (x, tmp) removed from all scopes."""
+     generic = {"x", "tmp"}
+     try:
+         tree = ast.parse(code)
+     except SyntaxError:
+         return 0.0
+
+     remaining: set[str] = set()
+
+     class _Collector(ast.NodeVisitor):
+         def visit_Name(self, node: ast.Name) -> None:
+             if node.id in generic:
+                 remaining.add(node.id)
+             self.generic_visit(node)
+
+         def visit_arg(self, node: ast.arg) -> None:
+             if node.arg in generic:
+                 remaining.add(node.arg)
+             self.generic_visit(node)
+
+     _Collector().visit(tree)
+     renamed = len(generic - remaining)
+     return renamed / len(generic)
+
+
+ # ---------------------------------------------------------------------------
+ # Task 2 - Medium: Remove dead code
+ # ---------------------------------------------------------------------------
+ _MEDIUM_CODE = """\
+ def process(data):
+     result = []
+     for item in data:
+         result.append(item * 2)
+     if False:
+         print("never runs")
+     unused_var = 42
+     return result
+     print("unreachable")
+ """
+
+
+ def _grade_medium(code: str) -> float:
+     """Score = fraction of 3 checks passed (1/3 each).
+
+     The checks are: `if False` block removed, `unused_var` removed, and the
+     append-loop rewritten as a list comprehension. The unreachable statement
+     after `return` is not checked.
+     """
+     try:
+         tree = ast.parse(code)
+     except SyntaxError:
+         return 0.0
+
+     source = ast.unparse(tree)
+     score = 0.0
+
+     # Check 1: if-False block removed
+     if "if False" not in source:
+         score += 1 / 3
+
+     # Check 2: unused_var assignment removed
+     if "unused_var" not in source:
+         score += 1 / 3
+
+     # Check 3: list comprehension used (loop simplified)
+     has_listcomp = any(isinstance(n, ast.ListComp) for n in ast.walk(tree))
+     if has_listcomp:
+         score += 1 / 3
+
+     return score
+
+
+ # ---------------------------------------------------------------------------
+ # Task 3 - Hard: Full refactor
+ # ---------------------------------------------------------------------------
+ _HARD_CODE = """\
+ def add(p, q):
+     return p + q
+
+ def compute(x, data, tmp):
+     result = []
+     for item in data:
+         result.append(item * 2)
+     if False:
+         y = 999
+     if True:
+         val = add(x, tmp)
+     unused = 0
+     flag = not not True
+     return val
+     print("dead")
+ """
+
+
+ def _grade_hard(code: str) -> float:
+     """Score = fraction of 5 quality checks passed."""
+     try:
+         tree = ast.parse(code)
+     except SyntaxError:
+         return 0.0
+
+     source = ast.unparse(tree)
+     checks = 0
+
+     # 1. No generic parameter names (x, tmp) in any function signature.
+     #    Only `ast.arg` nodes are inspected; names in function bodies are not.
+     has_generic = False
+
+     class _GenCheck(ast.NodeVisitor):
+         def visit_arg(self, node: ast.arg) -> None:
+             nonlocal has_generic
+             if node.arg in {"x", "tmp"}:
+                 has_generic = True
+
+     _GenCheck().visit(tree)
+     if not has_generic:
+         checks += 1
+
+     # 2. No if False block
+     if "if False" not in source:
+         checks += 1
+
+     # 3. if True removed (body inlined)
+     if "if True" not in source:
+         checks += 1
+
+     # 4. List comprehension used
+     if any(isinstance(n, ast.ListComp) for n in ast.walk(tree)):
+         checks += 1
+
+     # 5. add() call inlined (no call to 'add')
+     calls = [n for n in ast.walk(tree) if isinstance(n, ast.Call)]
+     fn_names = {c.func.id for c in calls if isinstance(c.func, ast.Name)}
+     if "add" not in fn_names:
+         checks += 1
+
+     return checks / 5
+
+
+ # ---------------------------------------------------------------------------
+ # Registry
+ # ---------------------------------------------------------------------------
+
+ class TaskRegistry:
+     def __init__(self) -> None:
+         self._tasks: Dict[str, Task] = {}
+         self._register_all()
+
+     def _register_all(self) -> None:
+         self._tasks["rename_variables"] = Task(
+             id="rename_variables",
+             name="Rename Variables (Easy)",
+             description="Rename generic variable names (x, tmp) to descriptive ones",
+             difficulty="easy",
+             initial_code=_EASY_CODE,
+             _grade_fn=_grade_easy,
+         )
+         self._tasks["remove_dead_code"] = Task(
+             id="remove_dead_code",
+             name="Remove Dead Code (Medium)",
+             description="Remove unreachable code, if False blocks, and unused variables",
+             difficulty="medium",
+             initial_code=_MEDIUM_CODE,
+             _grade_fn=_grade_medium,
+         )
+         self._tasks["full_refactor"] = Task(
+             id="full_refactor",
+             name="Full Refactor (Hard)",
+             description="Apply all transformations: rename, dead code, loops, conditions, inlining",
+             difficulty="hard",
+             initial_code=_HARD_CODE,
+             _grade_fn=_grade_hard,
+         )
+
+     def get_task(self, task_id: str) -> Optional[Task]:
+         return self._tasks.get(task_id)
+
+     def list_tasks(self) -> List[dict]:
+         return [
+             {
+                 "id": t.id,
+                 "name": t.name,
+                 "description": t.description,
+                 "difficulty": t.difficulty,
+                 "initial_code": t.initial_code,
+             }
+             for t in self._tasks.values()
+         ]
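The graders in `task_registry.py` score a submission by parsing it and walking the AST. The following is a minimal standalone sketch of that idea, not code from the repository: `fraction_renamed` is a hypothetical helper, and the three sample strings are made up for illustration.

```python
import ast


def fraction_renamed(code: str, generic: frozenset = frozenset({"x", "tmp"})) -> float:
    """Return the fraction of `generic` names that no longer appear in `code`.

    Hypothetical helper mirroring the approach of the easy grader: unparsable
    code scores 0.0, otherwise each generic name that is gone counts equally.
    """
    try:
        tree = ast.parse(code)
    except SyntaxError:
        return 0.0

    remaining = set()
    for node in ast.walk(tree):
        # Variable reads/writes are ast.Name nodes; parameters are ast.arg nodes.
        if isinstance(node, ast.Name) and node.id in generic:
            remaining.add(node.id)
        elif isinstance(node, ast.arg) and node.arg in generic:
            remaining.add(node.arg)
    return len(generic - remaining) / len(generic)


# Illustrative inputs (assumptions, not taken from the task set).
untouched = "def compute(x, y, tmp):\n    tmp = x + y\n    return tmp * 2\n"
half_done = "def compute(left, y, tmp):\n    tmp = left + y\n    return tmp * 2\n"
renamed = "def compute(left, right):\n    total = left + right\n    return total * 2\n"
```

Using `ast.walk` instead of a `NodeVisitor` subclass keeps the sketch short; both visit every node, so the result should match the visitor-based version for this kind of check.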
acre/training/__init__.py ADDED
@@ -0,0 +1,6 @@
+ """Training utilities for ACRE."""
+
+ from .train_agent import TrainConfig, train
+
+ __all__ = ["TrainConfig", "train"]
+
acre/training/train_agent.py ADDED
@@ -0,0 +1,75 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Optional
+
+ from acre.env.refactor_env import RefactorEnv
+
+
+ @dataclass(frozen=True)
+ class TrainConfig:
+     """Configuration stub for training."""
+
+     total_steps: int = 5_000
+     seed: Optional[int] = None
+     model_path: str = "acre_agent.zip"
+
+
+ def train(*, env: Optional[RefactorEnv] = None, config: Optional[TrainConfig] = None) -> None:
+     """
+     Train a PPO agent on `RefactorEnv` using Stable-Baselines3.
+
+     This is intentionally lightweight (hackathon-friendly) and focuses on a
+     working demo: basic training loop, simple logging, and saving the model.
+     """
+     _config = config or TrainConfig()
+     _env = env or RefactorEnv(seed=_config.seed)
+
+     try:
+         from stable_baselines3 import PPO
+         from stable_baselines3.common.callbacks import BaseCallback
+         from stable_baselines3.common.monitor import Monitor
+         from stable_baselines3.common.vec_env import DummyVecEnv
+     except Exception as e:  # pragma: no cover
+         print("Stable-Baselines3 is required for training. Install with `pip install -r requirements.txt`.")
+         print(f"Import error: {e}")
+         return None
+
+     class EpisodeRewardPrinter(BaseCallback):
+         """Print episode reward when an episode ends (via Monitor)."""
+
+         def __init__(self) -> None:
+             super().__init__()
+             self.episode_count = 0
+
+         def _on_step(self) -> bool:
+             infos = self.locals.get("infos", [])
+             for info in infos:
+                 ep = info.get("episode") if isinstance(info, dict) else None
+                 if isinstance(ep, dict) and "r" in ep:
+                     self.episode_count += 1
+                     print(f"episode={self.episode_count} reward={ep['r']:.2f} length={int(ep.get('l', 0))}")
+             return True
+
+     # Wrap with Monitor so SB3 can compute episode stats and expose them in `info["episode"]`.
+     def make_env() -> Monitor:
+         return Monitor(_env)
+
+     vec_env = DummyVecEnv([make_env])
+
+     model = PPO(
+         policy="MlpPolicy",
+         env=vec_env,
+         verbose=0,
+         seed=_config.seed,
+         n_steps=64,
+         batch_size=64,
+     )
+
+     print(f"Training PPO for {int(_config.total_steps)} timesteps...")
+     model.learn(total_timesteps=int(_config.total_steps), callback=EpisodeRewardPrinter())
+
+     model.save(_config.model_path)
+     print(f"Saved model to {_config.model_path!r}")
+     return None
+
acre/utils/__init__.py ADDED
@@ -0,0 +1,6 @@
+ """Shared utility helpers for ACRE."""
+
+ from .metrics import Metric, MetricLogger
+
+ __all__ = ["Metric", "MetricLogger"]
+
acre/utils/metrics.py ADDED
@@ -0,0 +1,33 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass, field
+ from typing import Dict, Iterable, List, Tuple
+
+
+ @dataclass(frozen=True)
+ class Metric:
+     """Single scalar metric value (placeholder)."""
+
+     name: str
+     value: float
+
+
+ @dataclass
+ class MetricLogger:
+     """Tiny metric logger stub."""
+
+     _history: Dict[str, List[float]] = field(default_factory=dict)
+
+     def log(self, metric: Metric) -> None:
+         self._history.setdefault(metric.name, []).append(metric.value)
+
+     def latest(self) -> Dict[str, float]:
+         return {k: v[-1] for k, v in self._history.items() if v}
+
+     def as_series(self) -> Dict[str, Tuple[float, ...]]:
+         return {k: tuple(v) for k, v in self._history.items()}
+
+     def extend(self, metrics: Iterable[Metric]) -> None:
+         for m in metrics:
+             self.log(m)
+
inference.py ADDED
@@ -0,0 +1,278 @@
+ """
+ ACRE inference script for OpenEnv submission evaluation.
+
+ Environment variables:
+     API_BASE_URL: LLM API endpoint (has a default)
+     MODEL_NAME: model identifier (has a default)
+     HF_TOKEN: API token for the OpenAI-compatible endpoint; if unset, a
+         built-in heuristic chooses actions instead of the LLM
+     ENV_URL: running ACRE server base URL (required)
+
+ Optional:
+     LOCAL_IMAGE_NAME: present for evaluator compatibility when using a local
+         Docker image launcher.
+
+ Stdout format uses strict START / STEP / END event markers.
+ """
+ from __future__ import annotations
+
+ import json
+ import os
+ import re
+ import sys
+ import time
+ from typing import Dict, List, Tuple
+
+ import requests
+ from openai import OpenAI
+
+ API_BASE_URL: str = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
+ MODEL_NAME: str = os.getenv("MODEL_NAME", "gpt-4o-mini")
+ HF_TOKEN: str | None = os.getenv("HF_TOKEN")
+ ENV_URL: str | None = os.getenv("ENV_URL")
+ LOCAL_IMAGE_NAME: str | None = os.getenv("LOCAL_IMAGE_NAME")
+
+ TASKS: List[str] = ["rename_variables", "remove_dead_code", "full_refactor"]
+
+ ACTION_MEANINGS: Dict[int, str] = {
+     0: "rename_variable",
+     1: "remove_dead_code",
+     2: "simplify_loop",
+     3: "optimize_condition",
+     4: "inline_function",
+ }
+
+ SYSTEM_PROMPT = """\
+ You are an RL agent that refactors Python code. Choose one action per step.
+
+ Actions:
+   0 rename_variable    - rename generic names (x, tmp, i) to descriptive ones
+   1 remove_dead_code   - remove unreachable stmts, if False blocks, unused vars
+   2 simplify_loop      - convert append-loops to list comprehensions
+   3 optimize_condition - simplify 'not not x', 'if True/False', 'x==True'
+   4 inline_function    - inline simple single-return module-level functions
+
+ Respond ONLY with valid JSON (no markdown):
+ {"action": <0-4>, "reason": "<one sentence>"}"""
+
+
+ def _env_url() -> str:
+     if ENV_URL:
+         return ENV_URL.rstrip("/")
+     raise RuntimeError("ENV_URL must be set before running inference.py")
+
+
+ def _post(path: str, payload: dict | None = None) -> dict:
+     response = requests.post(f"{_env_url()}{path}", json=payload or {}, timeout=30)
+     response.raise_for_status()
+     return response.json()
+
+
+ def _get(path: str) -> dict:
+     response = requests.get(f"{_env_url()}{path}", timeout=30)
+     response.raise_for_status()
+     return response.json()
+
+
+ def reset_env(task_id: str) -> dict:
+     return _post("/reset", {"task_id": task_id})
+
+
+ def step_env(action: int) -> dict:
+     return _post("/step", {"action": action})
+
+
+ def get_state() -> dict:
+     return _get("/state")
+
+
+ def grade(task_id: str, code: str) -> float:
+     response = requests.post(
+         f"{_env_url()}/tasks/{task_id}/grade",
+         json={"code": code},
+         timeout=30,
+     )
+     response.raise_for_status()
+     return float(response.json().get("score", 0.0))
+
+
+ def choose_action(client: OpenAI, state: dict, task_id: str) -> Tuple[int, str]:
+     def heuristic_action() -> Tuple[int, str]:
+         code = str(state.get("current_code", ""))
+         step_i = int(state.get("episode_steps", 0))
+
+         has_generic = re.search(r"\b(x|tmp|i)\b", code) is not None
+         has_if_false = re.search(r"\bif\s+False\b", code) is not None
+         has_if_true = re.search(r"\bif\s+True\b", code) is not None
+         has_append_loop = ".append(" in code and "for " in code
+         has_double_not = "not not" in code
+         has_add_call = "add(" in code
+
+         if task_id == "rename_variables":
+             if has_generic:
+                 return 0, "heuristic: remove generic names first"
+             if has_if_false or "unused" in code:
+                 return 1, "heuristic: remove dead code"
+             if has_append_loop:
+                 return 2, "heuristic: simplify loop"
+             if has_if_true or has_double_not:
+                 return 3, "heuristic: optimize conditions"
+             return 4, "heuristic: inline simple function"
+
+         if task_id == "remove_dead_code":
+             if has_if_false or "unused" in code:
+                 return 1, "heuristic: remove dead code patterns"
+             if has_append_loop:
+                 return 2, "heuristic: convert append-loop"
+             if has_if_true or has_double_not:
+                 return 3, "heuristic: simplify conditions"
+             if has_generic:
+                 return 0, "heuristic: clean generic names"
+             return 4, "heuristic: inline helper"
+
+         if has_generic:
+             return 0, "heuristic: rename generic variables"
+         if has_append_loop:
+             return 2, "heuristic: simplify loop into listcomp"
+         if has_if_false or has_if_true or has_double_not:
+             return 3, "heuristic: optimize boolean branches"
+         if has_add_call:
+             return 4, "heuristic: inline add() call"
+         if step_i >= 2:
+             return 1, "heuristic: remove remaining dead code"
+         return 3, "heuristic: condition optimization as safe default"
+
+     if not HF_TOKEN:
+         return heuristic_action()
+
+     messages = [
+         {"role": "system", "content": SYSTEM_PROMPT},
+         {
+             "role": "user",
+             "content": (
+                 f"Task: {task_id}\n"
+                 f"Steps remaining: {state.get('max_steps', 5) - state.get('episode_steps', 0)}\n"
+                 f"Complexity: {state.get('complexity', 0)}\n\n"
+                 f"Current code:\n```python\n{state.get('current_code', '')}\n```\n\n"
+                 "Choose the best action."
+             ),
+         },
+     ]
+     try:
+         response = client.chat.completions.create(
+             model=MODEL_NAME,
+             messages=messages,
+             temperature=0.0,
+             max_tokens=120,
+         )
+         raw = (response.choices[0].message.content or "").strip()
+         json_blob = raw
+
+         if "{" not in json_blob or "}" not in json_blob:
+             return heuristic_action()
+
+         match = re.search(r"\{.*\}", json_blob, flags=re.DOTALL)
+         if match:
+             json_blob = match.group(0)
+
+         parsed = json.loads(json_blob)
+         action = int(parsed.get("action", -1))
+         reason = str(parsed.get("reason", ""))
+         if 0 <= action <= 4:
+             return action, reason or "llm-selected action"
+         return heuristic_action()
+     except Exception:
+         return heuristic_action()
+
+
+ def run_episode(client: OpenAI, task_id: str, episode_num: int) -> float:
+     reset_env(task_id)
+     state = get_state()
+
+     print(
+         json.dumps(
+             {
+                 "event": "START",
+                 "episode": episode_num,
+                 "task_id": task_id,
+                 "initial_complexity": state.get("complexity", 0),
+                 "initial_code_length": len(state.get("current_code", "")),
+                 "timestamp": time.time(),
+             }
+         ),
+         flush=True,
+     )
+
+     cumulative_reward = 0.0
+
+     for step_num in range(1, 6):
+         action, reason = choose_action(client, state, task_id)
+         result = step_env(action)
+         state = get_state()
+
+         reward_payload = result.get("reward", {})
+         raw_reward = float(reward_payload.get("raw", 0.0))
+         # Fall back to the per-step formula from openenv.yaml: (raw + 32) / 52.
+         norm_reward = float(reward_payload.get("normalized", (raw_reward + 32) / 52))
+         cumulative_reward += raw_reward
+
+         print(
+             json.dumps(
+                 {
+                     "event": "STEP",
+                     "episode": episode_num,
+                     "step": step_num,
+                     "action": action,
+                     "action_name": ACTION_MEANINGS.get(action, "unknown"),
+                     "reason": reason,
+                     "reward": round(raw_reward, 4),
+                     "normalized_reward": round(norm_reward, 4),
+                     "cumulative_reward": round(cumulative_reward, 4),
+                     "changed": result.get("info", {}).get("changed", False),
+                     "reward_components": reward_payload.get("components", {}),
+                     "done": result.get("done", False),
+                 }
+             ),
+             flush=True,
+         )
+
+         if result.get("done") or result.get("terminated") or result.get("truncated"):
+             break
+
+     final_state = get_state()
+     task_score = grade(task_id, final_state.get("current_code", ""))
+
+     print(
+         json.dumps(
+             {
+                 "event": "END",
+                 "episode": episode_num,
+                 "task_id": task_id,
+                 "cumulative_reward": round(cumulative_reward, 4),
+                 # NOTE: this applies the per-step formula to the episode total,
+                 # so the value can fall outside [0, 1] over several steps.
+                 "normalized_cumulative": round((cumulative_reward + 32) / 52, 4),
+                 "task_score": round(task_score, 4),
+                 "final_complexity": final_state.get("complexity", 0),
+                 "timestamp": time.time(),
+             }
+         ),
+         flush=True,
+     )
+
+     return task_score
+
+
+ def main() -> None:
+     if not ENV_URL:
+         raise SystemExit("ENV_URL is required. Example: ENV_URL=http://localhost:7860")
+
+     client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN or "dummy")
+
+     scores: List[float] = []
+     for i, task_id in enumerate(TASKS, start=1):
+         score = run_episode(client, task_id, i)
+         scores.append(score)
+
+     # Exit code 0 only if the mean task score reaches 0.5.
+     avg_score = sum(scores) / len(scores) if scores else 0.0
+     sys.exit(0 if avg_score >= 0.5 else 1)
+
+
+ if __name__ == "__main__":
+     main()
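`choose_action` in `inference.py` pulls a JSON object out of the model reply and falls back to the heuristic whenever the reply is unusable. That parsing step can be sketched in isolation as below. `parse_action` is a hypothetical helper, not a function in the repository, and returning `None` to signal "use the fallback" is an assumption of this sketch.

```python
import json
import re
from typing import Optional


def parse_action(raw: str, n_actions: int = 5) -> Optional[int]:
    """Extract an in-range action index from a model reply, or None.

    Hypothetical helper: tolerates text or markdown fences around the JSON
    object by matching from the first '{' to the last '}'.
    """
    match = re.search(r"\{.*\}", raw, flags=re.DOTALL)
    if match is None:
        return None
    try:
        parsed = json.loads(match.group(0))
        action = int(parsed.get("action", -1))
    except (ValueError, TypeError, AttributeError):
        # json.JSONDecodeError is a ValueError; a non-numeric or null
        # "action" raises ValueError or TypeError in int().
        return None
    return action if 0 <= action < n_actions else None
```

A caller would treat `None` as the cue to run its fallback policy, which keeps the parsing logic free of any knowledge about the heuristic itself.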
models.py ADDED
@@ -0,0 +1,156 @@
+ from __future__ import annotations
+
+ from typing import Any, Dict, List, Optional, Sequence
+
+ from pydantic import BaseModel, Field
+
+
+ class ObservationModel(BaseModel):
+     code_length: float
+     complexity_score: float
+     runtime_s: float
+     error_flag: bool
+
+     @classmethod
+     def from_vector(cls, values: Sequence[float]) -> "ObservationModel":
+         vector = list(values)
+         if len(vector) != 4:
+             raise ValueError(f"observation vector must have length 4, got {len(vector)}")
+         return cls(
+             code_length=float(vector[0]),
+             complexity_score=float(vector[1]),
+             runtime_s=float(vector[2]),
+             error_flag=bool(vector[3]),
+         )
+
+     def to_vector(self) -> List[float]:
+         return [
+             float(self.code_length),
+             float(self.complexity_score),
+             float(self.runtime_s),
+             float(int(self.error_flag)),
+         ]
+
+
+ class ActionModel(BaseModel):
+     action: int = Field(ge=0, le=4)
+     action_name: Optional[str] = None
+
+
+ class RewardModel(BaseModel):
+     raw: float
+     normalized: float = Field(ge=0.0, le=1.0)
+     components: Dict[str, float]
+
+
+ class HealthResponse(BaseModel):
+     status: str
+     env: str
+     version: str
+
+
+ class CompatibilityHealthResponse(BaseModel):
+     status: str
+     service: str
+
+
+ class ResetRequest(BaseModel):
+     task_id: Optional[str] = None
+     seed: Optional[int] = None
+     code: Optional[str] = None
+
+
+ class StepRequest(BaseModel):
+     action: int = Field(ge=0, le=4)
+
+
+ class GradeRequest(BaseModel):
+     code: str
+
+
+ class TaskInfo(BaseModel):
+     id: str
+     name: str
+     description: str
+     difficulty: str
+     initial_code: str
+
+
+ class TasksResponse(BaseModel):
+     tasks: List[TaskInfo]
+
+
+ class GradeResponse(BaseModel):
+     task_id: str
+     score: float
+     passed: bool
+
+
+ class StateResponse(BaseModel):
+     current_code: str
+     episode_steps: int
+     max_steps: int
+     complexity: float
+     last_runtime: float
+     last_error: bool
+     sample_id: Optional[str]
+     language: Optional[str]
+     task_id: Optional[str]
+     observation: ObservationModel
+     observation_vector: List[float]
+     action_meanings: Dict[int, str]
+
+
+ class ResetResponse(BaseModel):
+     observation: ObservationModel
+     observation_vector: List[float]
+     info: Dict[str, Any]
+     task_id: Optional[str]
+     state: StateResponse
+
+
+ class StepResponse(BaseModel):
+     action: ActionModel
+     observation: ObservationModel
+     observation_vector: List[float]
+     reward: RewardModel
+     done: bool
+     terminated: bool
+     truncated: bool
+     info: Dict[str, Any]
+     state: StateResponse
+
+
+ class OptimizeRequest(BaseModel):
+     code: str
+     task_id: Optional[str] = None
+     max_steps: int = Field(default=5, ge=1, le=5)
+     use_rl: bool = True
+     use_llm: bool = False
+     fallback_to_llm: bool = True
+     rl_model_path: Optional[str] = None
+     api_base_url: Optional[str] = None
+     model_name: Optional[str] = None
+     api_token: Optional[str] = None
+
+
+ class OptimizationStep(BaseModel):
+     step: int
+     action: int
+     action_name: str
+     reason: str
+     source: str
+     reward: float
+     normalized_reward: float
+     changed: bool
+     complexity: float
+
+
+ class OptimizeResponse(BaseModel):
+     original_code: str
+     optimized_code: str
+     diff: str
+     steps: List[OptimizationStep]
+     cumulative_reward: float
+     task_id: Optional[str]
+     task_score: Optional[float]
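`ObservationModel` in `models.py` converts between a named observation and the 4-element vector the environment emits. The round trip can be illustrated without pydantic using a plain dataclass. The `Observation` class below is a stand-in written for this sketch only, and the sample values are invented.

```python
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class Observation:
    """Dataclass stand-in for the pydantic ObservationModel (illustration only)."""

    code_length: float
    complexity_score: float
    runtime_s: float
    error_flag: bool

    @classmethod
    def from_vector(cls, values: Sequence[float]) -> "Observation":
        vector = list(values)
        if len(vector) != 4:
            raise ValueError(f"observation vector must have length 4, got {len(vector)}")
        # The last slot is a 0/1 flag in the vector and a bool on the object.
        return cls(float(vector[0]), float(vector[1]), float(vector[2]), bool(vector[3]))

    def to_vector(self) -> List[float]:
        return [
            float(self.code_length),
            float(self.complexity_score),
            float(self.runtime_s),
            float(int(self.error_flag)),
        ]


obs = Observation.from_vector([120, 7, 0.25, 1])
```

Keeping the field order identical in `from_vector` and `to_vector` is what makes the conversion lossless; the order here follows the `fields` list used for the observation space.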
openenv.yaml ADDED
@@ -0,0 +1,85 @@
+ name: ACRE
+ version: "1.0.0"
+ description: >
+   Autonomous Code Refactoring Environment - an RL environment where an
+   agent improves Python code quality using AST-level transformations.
+ author: "Nikhil Pratap Singh, Pranav Mangal, Ananya Gupta"
+ entrypoint: "openenv_interface:OpenEnvRefactorEnv"
+ tags:
+   - openenv
+
+ tasks:
+   - id: rename_variables
+     name: "Rename Variables (Easy)"
+     description: "Rename generic variable names (x, tmp) to descriptive ones"
+     difficulty: easy
+     reward_range: [0.0, 1.0]
+     max_steps: 5
+
+   - id: remove_dead_code
+     name: "Remove Dead Code (Medium)"
+     description: "Remove unreachable statements, if-False blocks, and unused assignments"
+     difficulty: medium
+     reward_range: [0.0, 1.0]
+     max_steps: 5
+
+   - id: full_refactor
+     name: "Full Refactor (Hard)"
+     description: "Apply all transformations - rename, dead code removal, loop simplification, condition optimization, and function inlining"
+     difficulty: hard
+     reward_range: [0.0, 1.0]
+     max_steps: 5
+
+ observation_space:
+   type: Box
+   shape: [4]
+   dtype: float32
+   low: [0.0, 0.0, 0.0, 0.0]
+   # YAML spells infinity as `.inf`; a bare `inf` would load as a string.
+   high: [.inf, .inf, .inf, 1.0]
+   fields:
+     - code_length
+     - complexity_score
+     - runtime_s
+     - error_flag
+
+ action_space:
+   type: Discrete
+   n: 5
+   actions:
+     0: rename_variable
+     1: remove_dead_code
+     2: simplify_loop
+     3: optimize_condition
+     4: inline_function
+
+ api:
+   health: "GET /"
+   reset: "POST /reset"
+   step: "POST /step"
+   state: "GET /state"
+   tasks: "GET /tasks"
+   grade: "POST /tasks/{task_id}/grade"
+
+ reward:
+   raw_range: [-32, 20]
+   normalized_range: [0.0, 1.0]
+   formula: "(raw + 32) / 52"
+   components:
+     success: { max: 10, min: -10 }
+     complexity: { max: 5, min: -5 }
+     performance: { max: 5, min: -2 }
+     error: { max: 0, min: -15 }
+     no_change: { max: 0, min: -2 }
+
+ validation:
+   python_api:
+     reset: "ObservationModel"
+     step: "(ObservationModel, RewardModel, done, info)"
+     state: "StateResponse"
+   http_api:
+     health: "GET /"
+     reset: "POST /reset"
+     step: "POST /step"
+     state: "GET /state"
+     tasks: "GET /tasks"
+     grade: "POST /tasks/{task_id}/grade"
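The `reward` section of `openenv.yaml` maps a raw per-step reward in [-32, 20] onto [0, 1] with `(raw + 32) / 52`. A small sketch of that mapping follows. `normalize_reward` is a hypothetical name, and clamping out-of-range inputs is an assumption of this sketch rather than something the config states.

```python
RAW_MIN = -32.0  # lower bound of raw_range
RAW_MAX = 20.0   # upper bound of raw_range


def normalize_reward(raw: float) -> float:
    """Map a raw per-step reward onto [0, 1] via (raw + 32) / 52.

    Values outside the declared raw range are clamped (an assumption of
    this sketch, so the result always satisfies the normalized range).
    """
    span = RAW_MAX - RAW_MIN  # 52.0
    return min(1.0, max(0.0, (raw - RAW_MIN) / span))
```

The formula is defined for a single step. Applying it to a sum of several step rewards would need the bounds scaled by the number of steps, since an episode total can exceed the per-step range.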
openenv_interface.py ADDED
@@ -0,0 +1,116 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, Dict, Optional, Tuple
4
+
5
+ try:
6
+ from openenv.env import Env as OpenEnvBase
7
+ except Exception: # pragma: no cover
8
+ class OpenEnvBase:
9
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
10
+ return None
11
+
12
+ from acre.datasets.code_samples import CodeSample, CodeSampleDataset
13
+ from acre.env.refactor_env import RefactorEnv
14
+ from acre.tasks.task_registry import TaskRegistry
15
+ from models import ActionModel, ObservationModel, RewardModel, StateResponse
16
+
17
+
18
+ class OpenEnvRefactorEnv(OpenEnvBase):
+     """
+     Canonical OpenEnv interface for ACRE.
+
+     This wrapper keeps the strict hackathon contract:
+     - reset() -> ObservationModel
+     - step(action) -> (ObservationModel, RewardModel, done, info)
+     - state() -> StateResponse
+     """
+
+     def __init__(
+         self,
+         *,
+         env: Optional[RefactorEnv] = None,
+         registry: Optional[TaskRegistry] = None,
+     ) -> None:
+         super().__init__(
+             name="ACRE",
+             state_space="ObservationModel",
+             action_space="ActionModel",
+             episode_max_length=RefactorEnv.MAX_STEPS,
+         )
+         self._env = env or RefactorEnv()
+         self._registry = registry or TaskRegistry()
+         self._task_id: Optional[str] = None
+         self._last_reset_info: Dict[str, Any] = {}
+
+     @property
+     def action_meanings(self) -> Dict[int, str]:
+         return self._env.ACTION_MEANINGS
+
+     @property
+     def last_reset_info(self) -> Dict[str, Any]:
+         return dict(self._last_reset_info)
+
+     def _load_episode_source(self, *, task_id: Optional[str], code: Optional[str]) -> None:
+         initial_code = code
+         if initial_code is None and task_id:
+             task = self._registry.get_task(task_id)
+             if task is None:
+                 raise ValueError(f"Task '{task_id}' not found")
+             initial_code = task.initial_code
+
+         if initial_code is None:
+             return None
+
+         self._env.dataset = CodeSampleDataset(
+             [
+                 CodeSample(
+                     id=task_id or "custom",
+                     language="python",
+                     code=initial_code,
+                 )
+             ]
+         )
+         return None
+
+     def reset(
+         self,
+         *,
+         seed: Optional[int] = None,
+         task_id: Optional[str] = None,
+         code: Optional[str] = None,
+     ) -> ObservationModel:
+         self._load_episode_source(task_id=task_id, code=code)
+         self._task_id = task_id  # recorded only after the task lookup succeeds
+         observation, info = self._env.reset(seed=seed)
+         self._last_reset_info = dict(info)
+         return ObservationModel.from_vector(observation.tolist())
+
+     def step(self, action: int | ActionModel) -> Tuple[ObservationModel, RewardModel, bool, Dict[str, Any]]:
+         action_value = action.action if isinstance(action, ActionModel) else int(action)
+         observation, raw_reward, terminated, truncated, info = self._env.step(action_value)
+         reward = RewardModel(
+             raw=float(raw_reward),
+             normalized=float(info.get("normalized_reward", 0.0)),
+             components=dict(info.get("reward_components", {})),
+         )
+         done = bool(terminated or truncated)
+         return ObservationModel.from_vector(observation.tolist()), reward, done, dict(info)
+
+     def state(self) -> StateResponse:
+         raw_state = self._env.state()
+         observation_vector = list(raw_state.get("observation", [0.0, 0.0, 0.0, 0.0]))
+         observation = ObservationModel.from_vector(observation_vector)
+         return StateResponse(
+             current_code=str(raw_state.get("current_code", "")),
+             episode_steps=int(raw_state.get("episode_steps", 0)),
+             max_steps=int(raw_state.get("max_steps", RefactorEnv.MAX_STEPS)),
+             complexity=float(raw_state.get("complexity", 0.0)),
+             last_runtime=float(raw_state.get("last_runtime", 0.0)),
+             last_error=bool(raw_state.get("last_error", False)),
+             sample_id=raw_state.get("sample_id"),
+             language=raw_state.get("language"),
+             task_id=self._task_id,
+             observation=observation,
+             observation_vector=observation.to_vector(),
+             action_meanings=dict(raw_state.get("action_meanings", {})),
+         )
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ fastapi>=0.109.0
+ uvicorn[standard]>=0.27.0
+ numpy>=1.26
+ gymnasium
+ stable-baselines3
+ radon>=6.0.1
+ openai>=1.0.0
+ openenv>=0.1.13
+ requests>=2.31.0
+ pydantic>=2.0.0
+ typing_extensions>=4.0.0
server.py ADDED
@@ -0,0 +1,667 @@
+ """
+ ACRE OpenEnv HTTP server.
+
+ Endpoints (all required by OpenEnv spec):
+   GET  /                       - health check (must return HTTP 200)
+   POST /reset                  - reset environment, returns observation + info
+   POST /step                   - take one step, returns obs/reward/done/info
+   GET  /state                  - full current state snapshot
+   GET  /tasks                  - list all tasks with initial code
+   POST /tasks/{task_id}/grade  - grade code for a specific task
+ """
+ from __future__ import annotations
+
+ import difflib
+ import os
+ import re
+ import json
+ from typing import Optional
+
+ import uvicorn
+ import numpy as np
+ from fastapi import FastAPI, HTTPException
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.responses import HTMLResponse
+ from openai import OpenAI
+
+ try:
+     from stable_baselines3 import PPO
+ except Exception:
+     PPO = None  # type: ignore[assignment]
+
+ from acre.tasks.task_registry import TaskRegistry
+ from models import (
+     ActionModel,
+     CompatibilityHealthResponse,
+     GradeRequest,
+     GradeResponse,
+     HealthResponse,
+     OptimizationStep,
+     OptimizeRequest,
+     OptimizeResponse,
+     ResetRequest,
+     ResetResponse,
+     StateResponse,
+     StepRequest,
+     StepResponse,
+     TaskInfo,
+     TasksResponse,
+ )
+ from openenv_interface import OpenEnvRefactorEnv
+
+ DEFAULT_API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
+ DEFAULT_MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")
+ DEFAULT_RL_MODEL_PATH = os.getenv("RL_MODEL_PATH", "acre_agent.zip")
+
+ # ---------------------------------------------------------------------------
+ # App setup
+ # ---------------------------------------------------------------------------
+
+ app = FastAPI(
+     title="ACRE - Autonomous Code Refactoring Environment",
+     description="OpenEnv-compatible RL environment for Python code refactoring.",
+     version="1.0.0",
+ )
+
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # Global singletons
+ registry = TaskRegistry()
+ _env: Optional[OpenEnvRefactorEnv] = None
+ _rl_model_cache: dict[str, object] = {}
+
+
+ def get_env() -> OpenEnvRefactorEnv:
+     global _env
+     if _env is None:
+         _env = OpenEnvRefactorEnv(registry=registry)
+     return _env
+
+
+ def _state_response() -> StateResponse:
+     return get_env().state()
+
+
+ def _choose_action_heuristic(code: str, task_id: Optional[str]) -> int:
+     has_generic = re.search(r"\b(x|tmp|i)\b", code) is not None
+     has_if_false = re.search(r"\bif\s+False\b", code) is not None
+     has_if_true = re.search(r"\bif\s+True\b", code) is not None
+     has_append_loop = ".append(" in code and "for " in code
+     has_double_not = "not not" in code
+     has_add_call = "add(" in code
+
+     if task_id == "rename_variables":
+         if has_generic:
+             return 0
+         if has_if_false or "unused" in code:
+             return 1
+         if has_append_loop:
+             return 2
+         if has_if_true or has_double_not:
+             return 3
+         return 4
+
+     if task_id == "remove_dead_code":
+         if has_if_false or "unused" in code:
+             return 1
+         if has_append_loop:
+             return 2
+         if has_if_true or has_double_not:
+             return 3
+         if has_generic:
+             return 0
+         return 4
+
+     if has_generic:
+         return 0
+     if has_append_loop:
+         return 2
+     if has_if_false or has_if_true or has_double_not:
+         return 3
+     if has_add_call:
+         return 4
+     return 1
+
+
+ def _choose_action_llm(
+     *,
+     code: str,
+     task_id: Optional[str],
+     step_index: int,
+     max_steps: int,
+     api_base_url: str,
+     model_name: str,
+     api_token: str,
+ ) -> tuple[int, str, str]:
+     if not api_token.strip():
+         return _choose_action_heuristic(code, task_id), "empty token -> heuristic", "heuristic"
+
+     client = OpenAI(base_url=api_base_url, api_key=api_token)
+     messages = [
+         {
+             "role": "system",
+             "content": (
+                 "You are a code-refactoring action selector. Return ONLY compact JSON: "
+                 '{"action": <0-4>, "reason": "..."}.\n'
+                 "Actions: 0=rename_variable,1=remove_dead_code,2=simplify_loop,3=optimize_condition,4=inline_function"
+             ),
+         },
+         {
+             "role": "user",
+             "content": (
+                 f"task_id={task_id or 'auto'}\n"
+                 f"step={step_index}/{max_steps}\n"
+                 "Current code:\n"
+                 f"```python\n{code}\n```"
+             ),
+         },
+     ]
+     try:
+         resp = client.chat.completions.create(
+             model=model_name,
+             messages=messages,
+             temperature=0.0,
+             max_tokens=120,
+         )
+         raw = (resp.choices[0].message.content or "").strip()
+         m = re.search(r"\{.*\}", raw, flags=re.DOTALL)
+         blob = m.group(0) if m else raw
+         parsed = json.loads(blob)
+         action = int(parsed.get("action", -1))
+         reason = str(parsed.get("reason", "llm-selected action"))
+         if 0 <= action <= 4:
+             return action, reason, "llm"
+     except Exception as exc:
+         return _choose_action_heuristic(code, task_id), f"llm error -> heuristic: {exc}", "heuristic"
+
+     return _choose_action_heuristic(code, task_id), "invalid llm output -> heuristic", "heuristic"
+
+
+ def _choose_action_rl(observation: list[float], model_path: str) -> tuple[Optional[int], str, str]:
+     if PPO is None:
+         return None, "stable-baselines3 unavailable", "rl"
+     if not os.path.exists(model_path):
+         return None, f"rl model not found: {model_path}", "rl"
+
+     try:
+         model = _rl_model_cache.get(model_path)
+         if model is None:
+             model = PPO.load(model_path)
+             _rl_model_cache[model_path] = model
+
+         obs = np.asarray(observation, dtype=np.float32)
+         action, _ = model.predict(obs, deterministic=True)
+         action_i = int(action)
+         if 0 <= action_i <= 4:
+             return action_i, "rl policy action", "rl"
+         return None, f"invalid rl action: {action_i}", "rl"
+     except Exception as exc:
+         return None, f"rl failure: {exc}", "rl"
+
+
+ def _demo_html() -> str:
+     return """<!doctype html>
+ <html lang=\"en\">
+ <head>
+   <meta charset=\"utf-8\" />
+   <meta name=\"viewport\" content=\"width=device-width, initial-scale=1\" />
+   <title>ACRE Refactor Demo</title>
+   <style>
+     @import url('https://fonts.googleapis.com/css2?family=Space+Grotesk:wght@400;600;700&display=swap');
+     :root {
+       --bg0: #0b1f2a;
+       --bg1: #14344a;
+       --ink: #eaf7ff;
+       --muted: #a7c8db;
+       --brand: #1ec28b;
+       --warn: #ffcb47;
+       --panel: rgba(8, 24, 36, 0.72);
+       --stroke: rgba(140, 197, 225, 0.35);
+     }
+     * { box-sizing: border-box; }
+     body {
+       margin: 0;
+       color: var(--ink);
+       font-family: 'Space Grotesk', sans-serif;
+       background:
+         radial-gradient(circle at 12% 18%, rgba(30, 194, 139, 0.28), transparent 35%),
+         radial-gradient(circle at 88% 8%, rgba(255, 203, 71, 0.22), transparent 30%),
+         linear-gradient(150deg, var(--bg0), var(--bg1));
+       min-height: 100vh;
+     }
+     .wrap {
+       max-width: 1200px;
+       margin: 0 auto;
+       padding: 28px 20px 40px;
+     }
+     h1 {
+       margin: 0 0 6px;
+       font-size: clamp(1.6rem, 2vw + 1rem, 2.6rem);
+       letter-spacing: 0.2px;
+     }
+     .sub { margin: 0 0 20px; color: var(--muted); }
+     .grid {
+       display: grid;
+       grid-template-columns: 1fr;
+       gap: 16px;
+     }
+     .panel {
+       border: 1px solid var(--stroke);
+       border-radius: 14px;
+       background: var(--panel);
+       backdrop-filter: blur(4px);
+       padding: 14px;
+     }
+     .controls {
+       display: grid;
+       grid-template-columns: 1fr 1fr;
+       gap: 8px;
+       margin-bottom: 10px;
+     }
+     textarea, pre {
+       width: 100%;
+       min-height: 260px;
+       border: 1px solid var(--stroke);
+       border-radius: 10px;
+       padding: 12px;
+       background: rgba(1, 13, 24, 0.82);
+       color: #dcf4ff;
+       font-family: Consolas, 'Courier New', monospace;
+       font-size: 13px;
+       line-height: 1.4;
+       overflow: auto;
+       white-space: pre;
+     }
+     button, select {
+       border: 1px solid var(--stroke);
+       border-radius: 10px;
+       padding: 10px 12px;
+       background: rgba(11, 36, 52, 0.9);
+       color: var(--ink);
+       font-weight: 600;
+     }
+     button.primary {
+       background: linear-gradient(120deg, #19a7ff, #1ec28b);
+       color: #032235;
+       border: none;
+     }
+     .cols {
+       display: grid;
+       grid-template-columns: 1fr;
+       gap: 14px;
+     }
+     .meta {
+       color: var(--muted);
+       font-size: 0.92rem;
+       margin-top: 8px;
+     }
+     .badge {
+       color: #082b22;
+       background: var(--brand);
+       border-radius: 999px;
+       padding: 2px 9px;
+       font-size: 12px;
+       font-weight: 700;
+     }
+     .warn {
+       color: #2a1c00;
+       background: var(--warn);
+     }
+     @media (min-width: 900px) {
+       .cols { grid-template-columns: 1fr 1fr; }
+     }
+   </style>
+ </head>
+ <body>
+   <div class=\"wrap\">
+     <h1>ACRE Live Refactor Arena</h1>
+     <p class=\"sub\">Paste old code, run the agent, and compare before and after with a full diff and step-by-step rewards.</p>
+
+     <div class=\"panel\">
+       <div class=\"controls\">
+         <button onclick=\"loadExample(1)\">Load Example 1</button>
+         <button onclick=\"loadExample(2)\">Load Example 2</button>
+         <select id=\"task\">
+           <option value=\"\">Auto strategy</option>
+           <option value=\"rename_variables\">rename_variables</option>
+           <option value=\"remove_dead_code\">remove_dead_code</option>
+           <option value=\"full_refactor\">full_refactor</option>
+         </select>
+         <button class=\"primary\" onclick=\"runOptimize()\">Run Optimization</button>
+       </div>
+       <div class=\"controls\" style=\"margin-bottom: 10px;\">
+         <select id=\"mode\">
+           <option value=\"rl_then_llm\">RL First -> LLM Fallback</option>
+           <option value=\"heuristic\">Heuristic Agent (no API key)</option>
+           <option value=\"llm\">LLM Agent (OpenAI-compatible API)</option>
+         </select>
+         <input id=\"rlModelPath\" placeholder=\"RL model path\" value=\"acre_agent.zip\" style=\"border:1px solid var(--stroke);border-radius:10px;padding:10px 12px;background:rgba(1,13,24,0.82);color:#dcf4ff;\" />
+         <input id=\"baseUrl\" placeholder=\"API base URL (optional)\" value=\"https://api.openai.com/v1\" style=\"border:1px solid var(--stroke);border-radius:10px;padding:10px 12px;background:rgba(1,13,24,0.82);color:#dcf4ff;\" />
+         <input id=\"modelName\" placeholder=\"Model name (optional)\" value=\"gpt-4o-mini\" style=\"border:1px solid var(--stroke);border-radius:10px;padding:10px 12px;background:rgba(1,13,24,0.82);color:#dcf4ff;\" />
+         <input id=\"apiToken\" type=\"password\" placeholder=\"Paste API token here for LLM mode\" style=\"border:1px solid var(--stroke);border-radius:10px;padding:10px 12px;background:rgba(1,13,24,0.82);color:#dcf4ff;\" />
+       </div>
+       <div class=\"controls\" style=\"margin-bottom: 10px;\">
+         <label style=\"display:flex;align-items:center;gap:8px;padding:8px 10px;border:1px solid var(--stroke);border-radius:10px;\">
+           <input id=\"autoSuggest\" type=\"checkbox\" />
+           Auto suggest after typing pause
+         </label>
+       </div>
+       <textarea id=\"input\" spellcheck=\"false\" placeholder=\"Paste your Python code here...\"></textarea>
+       <p class=\"meta\" id=\"status\">Status: ready</p>
+     </div>
+
+     <div class=\"cols\" style=\"margin-top: 14px\">
+       <div class=\"panel\">
+         <h3>Original Code</h3>
+         <pre id=\"original\"></pre>
+       </div>
+       <div class=\"panel\">
+         <h3>Optimized Code</h3>
+         <pre id=\"optimized\"></pre>
+       </div>
+     </div>
+
+     <div class=\"panel\" style=\"margin-top: 14px\">
+       <h3>Diff</h3>
+       <pre id=\"diff\"></pre>
+     </div>
+
+     <div class=\"panel\" style=\"margin-top: 14px\">
+       <h3>Step Logs</h3>
+       <pre id=\"steps\"></pre>
+     </div>
+   </div>
+
+   <script>
+     const EX1 = `def compute(x, y, tmp):\n    tmp = x + y\n    x = tmp * 2\n    result = x\n    return result\n`;
+     const EX2 = `def add(p, q):\n    return p + q\n\ndef compute(x, data, tmp):\n    result = []\n    for item in data:\n        result.append(item * 2)\n    if False:\n        y = 999\n    if True:\n        val = add(x, tmp)\n    unused = 0\n    flag = not not True\n    return val\n    print(\"dead\")\n`;
+     let autoTimer = null;
+
+     function loadExample(i) {
+       document.getElementById('input').value = i === 1 ? EX1 : EX2;
+       document.getElementById('status').textContent = `Status: loaded example ${i}`;
+     }
+
+     async function runOptimize() {
+       const code = document.getElementById('input').value;
+       const task = document.getElementById('task').value || null;
+       const mode = document.getElementById('mode').value;
+       const useRl = mode === 'rl_then_llm';
+       const useLlm = mode === 'llm' || mode === 'rl_then_llm';
+       const fallbackToLlm = mode === 'rl_then_llm';
+       const rlModelPath = document.getElementById('rlModelPath').value || null;
+       const apiToken = document.getElementById('apiToken').value || null;
+       const apiBaseUrl = document.getElementById('baseUrl').value || null;
+       const modelName = document.getElementById('modelName').value || null;
+       if (!code.trim()) {
+         document.getElementById('status').innerHTML = 'Status: <span class=\"badge warn\">please paste code first</span>';
+         return;
+       }
+       if (mode === 'llm' && (!apiToken || !apiToken.trim())) {
+         document.getElementById('status').innerHTML = 'Status: <span class=\"badge warn\">paste API token for LLM mode</span>';
+         return;
+       }
+
+       document.getElementById('status').textContent = 'Status: running optimization...';
+       try {
+         const res = await fetch('/optimize', {
+           method: 'POST',
+           headers: {'Content-Type': 'application/json'},
+           body: JSON.stringify({
+             code,
+             task_id: task,
+             max_steps: 5,
+             use_rl: useRl,
+             use_llm: useLlm,
+             fallback_to_llm: fallbackToLlm,
+             rl_model_path: rlModelPath,
+             api_base_url: apiBaseUrl,
+             model_name: modelName,
+             api_token: apiToken,
+           })
+         });
+         const data = await res.json();
+         if (!res.ok) {
+           throw new Error(data.detail || 'request failed');
+         }
+
+         document.getElementById('original').textContent = data.original_code;
+         document.getElementById('optimized').textContent = data.optimized_code;
+         document.getElementById('diff').textContent = data.diff || '(no diff)';
+         document.getElementById('steps').textContent = JSON.stringify(data.steps, null, 2);
+
+         const scoreText = data.task_score === null ? 'n/a' : data.task_score;
+         document.getElementById('status').innerHTML = `Status: <span class=\"badge\">done</span> cumulative_reward=${data.cumulative_reward.toFixed(2)} task_score=${scoreText}`;
+       } catch (err) {
+         document.getElementById('status').innerHTML = `Status: <span class=\"badge warn\">error</span> ${err.message}`;
+       }
+     }
+
+     loadExample(1);
+     document.getElementById('input').addEventListener('input', () => {
+       if (!document.getElementById('autoSuggest').checked) {
+         return;
+       }
+       if (autoTimer) {
+         clearTimeout(autoTimer);
+       }
+       autoTimer = setTimeout(() => {
+         runOptimize();
+       }, 1200);
+     });
+   </script>
+ </body>
+ </html>"""
+
+
+ # ---------------------------------------------------------------------------
+ # Routes
+ # ---------------------------------------------------------------------------
+
+ @app.get("/", response_model=HealthResponse)
+ def health() -> HealthResponse:
+     """Health check: OpenEnv pings this URL to verify the Space is live."""
+     return HealthResponse(status="ok", env="ACRE", version="1.0.0")
+
+
+ @app.get("/health", response_model=CompatibilityHealthResponse)
+ def health_compat() -> CompatibilityHealthResponse:
+     """Compatibility health route used by some OpenEnv reference environments."""
+     return CompatibilityHealthResponse(status="healthy", service="acre-env")
+
+
+ @app.get("/demo", response_class=HTMLResponse)
+ def demo_ui() -> HTMLResponse:
+     """Simple UI to compare original and optimized code side-by-side."""
+     return HTMLResponse(content=_demo_html())
+
+
+ @app.post("/reset", response_model=ResetResponse)
+ def reset(req: ResetRequest = ResetRequest()) -> ResetResponse:
+     """Reset the environment. Optionally load a task's initial code."""
+     env = get_env()
+     try:
+         obs = env.reset(seed=req.seed, task_id=req.task_id, code=req.code)
+     except ValueError as exc:
+         raise HTTPException(status_code=404, detail=str(exc)) from exc
+     return ResetResponse(
+         observation=obs,
+         observation_vector=obs.to_vector(),
+         info=env.last_reset_info,
+         task_id=req.task_id,
+         state=_state_response(),
+     )
+
+
+ @app.post("/step", response_model=StepResponse)
+ def step(req: StepRequest) -> StepResponse:
+     """Take one refactoring step."""
+     env = get_env()
+     if not (0 <= req.action <= 4):
+         raise HTTPException(status_code=400, detail="action must be 0–4")
+
+     obs, reward, done, info = env.step(req.action)
+     action_name = str(info.get("action_name", env.action_meanings.get(req.action, "unknown")))
+
+     return StepResponse(
+         action=ActionModel(action=req.action, action_name=action_name),
+         observation=obs,
+         observation_vector=obs.to_vector(),
+         reward=reward,
+         done=done,
+         terminated=done,
+         truncated=False,
+         info=info,
+         state=_state_response(),
+     )
+
+
+ @app.get("/state", response_model=StateResponse)
+ def state() -> StateResponse:
+     """Return full current environment state (OpenEnv spec requirement)."""
+     return _state_response()
+
+
+ @app.get("/tasks", response_model=TasksResponse)
+ def list_tasks() -> TasksResponse:
+     """Enumerate all tasks (easy → medium → hard)."""
+     return TasksResponse(tasks=[TaskInfo.model_validate(t) for t in registry.list_tasks()])
+
+
+ @app.post("/tasks/{task_id}/grade", response_model=GradeResponse)
+ def grade(task_id: str, req: GradeRequest) -> GradeResponse:
+     """Grade submitted code against a task's grader (returns score 0.0–1.0)."""
+     task = registry.get_task(task_id)
+     if task is None:
+         raise HTTPException(status_code=404, detail=f"Task '{task_id}' not found")
+     score = task.grade(req.code)
+     return GradeResponse(
+         task_id=task_id,
+         score=round(score, 4),
+         passed=score >= 0.8,
+     )
+
+
+ @app.post("/optimize", response_model=OptimizeResponse)
+ def optimize(req: OptimizeRequest) -> OptimizeResponse:
+     """Run a full optimization episode and return code comparison artifacts."""
+     code = req.code.strip("\n")
+     if not code.strip():
+         raise HTTPException(status_code=400, detail="code must be non-empty")
+
+     env = get_env()
+     try:
+         env.reset(task_id=req.task_id, code=code)
+     except ValueError as exc:
+         raise HTTPException(status_code=404, detail=str(exc)) from exc
+
+     steps: list[OptimizationStep] = []
+     cumulative_reward = 0.0
+
+     for step_idx in range(1, req.max_steps + 1):
+         state_now = env.state()
+         current_code = state_now.current_code
+         obs_list = [float(x) for x in state_now.observation_vector]
+
+         action: int
+         reason: str
+         source: str
+
+         if req.use_rl:
+             rl_action, rl_reason, rl_source = _choose_action_rl(
+                 observation=obs_list,
+                 model_path=req.rl_model_path or DEFAULT_RL_MODEL_PATH,
+             )
+             if rl_action is not None:
+                 action, reason, source = rl_action, rl_reason, rl_source
+             elif req.fallback_to_llm and req.use_llm:
+                 action, reason, source = _choose_action_llm(
+                     code=current_code,
+                     task_id=req.task_id,
+                     step_index=step_idx,
+                     max_steps=req.max_steps,
+                     api_base_url=req.api_base_url or DEFAULT_API_BASE_URL,
+                     model_name=req.model_name or DEFAULT_MODEL_NAME,
+                     api_token=req.api_token or "",
+                 )
+                 reason = f"{rl_reason}; {reason}"
+             else:
+                 action = _choose_action_heuristic(current_code, req.task_id)
+                 reason = f"{rl_reason}; heuristic fallback"
+                 source = "heuristic"
+         elif req.use_llm:
+             action, reason, source = _choose_action_llm(
+                 code=current_code,
+                 task_id=req.task_id,
+                 step_index=step_idx,
+                 max_steps=req.max_steps,
+                 api_base_url=req.api_base_url or DEFAULT_API_BASE_URL,
+                 model_name=req.model_name or DEFAULT_MODEL_NAME,
+                 api_token=req.api_token or "",
+             )
+         else:
+             action = _choose_action_heuristic(current_code, req.task_id)
+             reason = "heuristic policy"
+             source = "heuristic"
+
+         _, reward, done, info = env.step(action)
+         state_now = env.state()
+
+         cumulative_reward += float(reward.raw)
+         steps.append(
+             OptimizationStep(
+                 step=step_idx,
+                 action=action,
+                 action_name=info.get("action_name", "unknown"),
+                 reason=reason,
+                 source=source,
+                 reward=float(reward.raw),
+                 normalized_reward=float(reward.normalized),
+                 changed=bool(info.get("changed", False)),
+                 complexity=float(state_now.complexity),
+             )
+         )
+
+         if done:
+             break
+
+     final_code = str(env.state().current_code)
+     diff_lines = difflib.unified_diff(
+         code.splitlines(),
+         final_code.splitlines(),
+         fromfile="original.py",
+         tofile="optimized.py",
+         lineterm="",
+     )
+     diff_text = "\n".join(diff_lines)
+
+     task_score: Optional[float] = None
+     if req.task_id:
+         task = registry.get_task(req.task_id)
+         if task is None:
+             raise HTTPException(status_code=404, detail=f"Task '{req.task_id}' not found")
+         task_score = round(task.grade(final_code), 4)
+
+     return OptimizeResponse(
+         original_code=code,
+         optimized_code=final_code,
+         diff=diff_text,
+         steps=steps,
+         cumulative_reward=round(cumulative_reward, 4),
+         task_id=req.task_id,
+         task_score=task_score,
+     )
+
+
+ # ---------------------------------------------------------------------------
+ # Entry point
+ # ---------------------------------------------------------------------------
+
+ if __name__ == "__main__":
+     port = int(os.getenv("PORT", "7860"))
+     uvicorn.run(app, host="0.0.0.0", port=port)
validate.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ACRE pre-submission validator.
3
+
4
+ Checks the repository against the submission checklist and, when a server URL is
5
+ available, probes the HTTP API as well.
6
+
7
+ Run:
8
+ python validate.py --url http://localhost:7860
9
+ """
10
+ from __future__ import annotations
11
+
12
+ import argparse
13
+ import ast
14
+ import re
15
+ import sys
16
+ from typing import Any, Tuple
17
+
18
+ try:
19
+ import requests
20
+ except ImportError:
21
+ print("[ERROR] requests is required. Run: pip install requests")
22
+ sys.exit(1)
23
+
24
+ PASS = "\033[92m[PASS]\033[0m"
25
+ FAIL = "\033[91m[FAIL]\033[0m"
26
+
27
+
28
+ def check(label: str, ok: bool, detail: str = "") -> bool:
29
+ status = PASS if ok else FAIL
30
+ message = f" {status} {label}"
31
+ if detail:
32
+ message += f" - {detail}"
33
+ print(message)
34
+ return ok
35
+
36
+
37
+ def get(url: str, path: str, timeout: int = 15) -> Tuple[bool, Any]:
38
+ try:
39
+ response = requests.get(f"{url}{path}", timeout=timeout)
40
+ response.raise_for_status()
41
+ return True, response.json()
42
+ except Exception as exc:
43
+ return False, str(exc)
44
+
45
+
46
+ def post(url: str, path: str, payload: dict, timeout: int = 15) -> Tuple[bool, Any]:
47
+ try:
48
+ response = requests.post(f"{url}{path}", json=payload, timeout=timeout)
49
+ response.raise_for_status()
50
+ return True, response.json()
51
+ except Exception as exc:
52
+ return False, str(exc)
53
+
54
+
55
+ def read_text(path: str) -> str:
56
+ with open(path, encoding="utf-8") as handle:
57
+ return handle.read()
58
+
59
+
60
+ def run_validation(base_url: str) -> int:
61
+ failures = 0
62
+
63
+ print("\n" + "=" * 60)
64
+ print(" ACRE Pre-Submission Validator")
65
+ print("=" * 60)
66
+ print(f" Target: {base_url}\n")
67
+
68
+ print("1. Static repository checks")
69
+ try:
70
+ interface_src = read_text("openenv_interface.py")
71
+ tree = ast.parse(interface_src)
72
+ classes = {node.name: node for node in tree.body if isinstance(node, ast.ClassDef)}
73
+ env_cls = classes.get("OpenEnvRefactorEnv")
74
+ failures += 0 if check("openenv_interface.py exists", True) else 1
75
+ failures += 0 if check("OpenEnvRefactorEnv is defined", env_cls is not None) else 1
76
+ if env_cls is not None:
77
+ methods = {node.name for node in env_cls.body if isinstance(node, ast.FunctionDef)}
78
+ for method_name in ["reset", "step", "state"]:
79
+ failures += 0 if check(
80
+ f"OpenEnvRefactorEnv implements {method_name}()",
81
+ method_name in methods,
82
+ ) else 1
83
+ except FileNotFoundError:
84
+ failures += 1
85
+ check("openenv_interface.py exists", False, "file not found")
86
+
87
+ try:
88
+ models_src = read_text("models.py")
89
+ for name in ["ObservationModel", "ActionModel", "RewardModel"]:
90
+ failures += 0 if check(
91
+ f"{name} is defined in models.py",
92
+ f"class {name}" in models_src,
93
+ ) else 1
94
+ except FileNotFoundError:
95
+ failures += 1
96
+ check("models.py exists", False, "file not found")
97
+
98
+ print("\n2. Health check (GET /)")
99
+ ok, data = get(base_url, "/")
100
+ failures += 0 if check("GET / returns HTTP 200", ok) else 1
101
+ if ok:
102
+ failures += 0 if check(
103
+ "Response has status field",
104
+ isinstance(data, dict) and "status" in data,
105
+ str(data),
106
+ ) else 1
107
+
+     print("\n3. Tasks (GET /tasks)")
+     ok, data = get(base_url, "/tasks")
+     failures += 0 if check("GET /tasks returns 200", ok) else 1
+     if ok:
+         tasks = data.get("tasks", []) if isinstance(data, dict) else []
+         failures += 0 if check("At least 3 tasks defined", len(tasks) >= 3, f"found {len(tasks)}") else 1
+         difficulties = [t.get("difficulty", "") for t in tasks]
+         for diff in ["easy", "medium", "hard"]:
+             failures += 0 if check(f"Task with difficulty '{diff}' exists", diff in difficulties) else 1
+         for task in tasks:
+             failures += 0 if check(
+                 f"Task '{task.get('id')}' has initial_code",
+                 bool(task.get("initial_code")),
+             ) else 1
+
+     print("\n4. Reset (POST /reset)")
+     ok, data = post(base_url, "/reset", {})
+     failures += 0 if check("POST /reset returns 200", ok) else 1
+     if ok:
+         observation = data.get("observation", {})
+         failures += 0 if check("Response has observation field", isinstance(observation, dict)) else 1
+         failures += 0 if check(
+             "Observation is typed with 4 fields",
+             {"code_length", "complexity_score", "runtime_s", "error_flag"}.issubset(observation),
+             str(observation),
+         ) else 1
+
+     ok, _ = post(base_url, "/reset", {"task_id": "rename_variables"})
+     failures += 0 if check("POST /reset with task_id works", ok) else 1
+
+     print("\n5. State (GET /state)")
+     ok, data = get(base_url, "/state")
+     failures += 0 if check("GET /state returns 200", ok) else 1
+     if ok:
+         required_keys = [
+             "current_code",
+             "episode_steps",
+             "max_steps",
+             "complexity",
+             "observation",
+             "observation_vector",
+             "action_meanings",
+         ]
+         for key in required_keys:
+             failures += 0 if check(f"State has '{key}' field", key in data) else 1
+
+     print("\n6. Step (POST /step)")
+     post(base_url, "/reset", {"task_id": "rename_variables"})
+     for action in range(5):
+         ok, data = post(base_url, "/step", {"action": action})
+         failures += 0 if check(
+             f"Action {action} executes without error",
+             ok and isinstance(data, dict) and "reward" in data and "done" in data,
+         ) else 1
+         if ok:
+             reward_payload = data.get("reward", {})
+             norm = reward_payload.get("normalized", -1)
+             failures += 0 if check(
+                 f"Action {action} returns typed reward payload",
+                 {"raw", "normalized", "components"}.issubset(reward_payload),
+                 str(reward_payload),
+             ) else 1
+             failures += 0 if check(
+                 f"Action {action} normalized_reward in [0,1]",
+                 isinstance(norm, (int, float)) and 0.0 <= float(norm) <= 1.0,
+                 f"got {norm}",
+             ) else 1
+             if data.get("done"):
+                 break
+
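+     ok, data = post(base_url, "/step", {"action": 99})
+     # Count this check too, so a printed FAIL is always reflected in the final summary.
+     failures += 0 if check(
+         "Invalid action returns error (not crash)",
+         not ok or "detail" in str(data),
+         "(expected 4xx)",
+     ) else 1
+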
+     print("\n7. Task graders (POST /tasks/{id}/grade)")
+     for task_id in ["rename_variables", "remove_dead_code", "full_refactor"]:
+         ok, data = post(base_url, f"/tasks/{task_id}/grade", {"code": "def f(): pass"})
+         failures += 0 if check(f"Grade endpoint for '{task_id}' works", ok) else 1
+         if ok:
+             score = data.get("score", -1)
+             failures += 0 if check(
+                 f"Score for '{task_id}' in [0.0, 1.0]",
+                 isinstance(score, (int, float)) and 0.0 <= float(score) <= 1.0,
+                 f"got {score}",
+             ) else 1
+
+     print("\n8. openenv.yaml")
+     try:
+         openenv_yaml = read_text("openenv.yaml")
+         failures += 0 if check("openenv.yaml exists", True) else 1
+         for field in ["tasks:", "action_space:", "observation_space:", "reward:", "entrypoint:", "validation:"]:
+             failures += 0 if check(f"openenv.yaml has '{field}' section", field in openenv_yaml) else 1
+     except FileNotFoundError:
+         failures += 1
+         check("openenv.yaml exists", False, "file not found")
+
+     print("\n9. inference.py")
+     try:
+         inference_src = read_text("inference.py")
+         failures += 0 if check("inference.py exists", True) else 1
+         for marker in ['"event": "START"', '"event": "STEP"', '"event": "END"']:
+             failures += 0 if check(f"inference.py emits {marker}", marker in inference_src) else 1
+         failures += 0 if check(
+             "Uses OpenAI client",
+             "from openai import OpenAI" in inference_src,
+         ) else 1
+         for var in ["API_BASE_URL", "MODEL_NAME", "HF_TOKEN", "ENV_URL", "LOCAL_IMAGE_NAME"]:
+             failures += 0 if check(f"inference.py reads {var} from env", var in inference_src) else 1
+         failures += 0 if check(
+             "API_BASE_URL has a default",
+             'os.getenv("API_BASE_URL", "https://api.openai.com/v1")' in inference_src,
+         ) else 1
+         failures += 0 if check(
+             "MODEL_NAME has a default",
+             'os.getenv("MODEL_NAME", "gpt-4o-mini")' in inference_src,
+         ) else 1
+         failures += 0 if check(
+             "HF_TOKEN has no default",
+             re.search(r'HF_TOKEN\s*:\s*.*os\.getenv\("HF_TOKEN"\)', inference_src) is not None,
+         ) else 1
+     except FileNotFoundError:
+         failures += 1
+         check("inference.py exists", False, "file not found")
+
+     print("\n10. Dockerfile")
+     try:
+         dockerfile = read_text("Dockerfile")
+         failures += 0 if check("Dockerfile exists", True) else 1
+         failures += 0 if check("Exposes port 7860", "7860" in dockerfile) else 1
+         failures += 0 if check("Has CMD/ENTRYPOINT", "CMD" in dockerfile or "ENTRYPOINT" in dockerfile) else 1
+         failures += 0 if check("Does not set a default HF_TOKEN", "ENV HF_TOKEN" not in dockerfile) else 1
+     except FileNotFoundError:
+         failures += 1
+         check("Dockerfile exists", False, "file not found")
+
+     print("\n11. README / Hugging Face metadata")
+     try:
+         readme = read_text("README.md")
+         failures += 0 if check("README has docker SDK front matter", "sdk: docker" in readme) else 1
+         failures += 0 if check("README includes openenv tag", "openenv" in readme) else 1
+         for section in [
+             "Environment Overview and Motivation",
+             "Definitions of Action and Observation Spaces",
+             "Task Descriptions with Expected Difficulty Levels",
+             "Setup and Usage Instructions",
+             "Baseline Performance Scores",
+         ]:
+             failures += 0 if check(f"README includes '{section}'", section in readme) else 1
+     except FileNotFoundError:
+         failures += 1
+         check("README.md exists", False, "file not found")
+
+     print("\n" + "=" * 60)
+     if failures == 0:
+         print(f" {PASS} All checks passed. Repository is submission-ready.")
+     else:
+         print(f" {FAIL} {failures} check(s) failed. Fix before submitting.")
+     print("=" * 60 + "\n")
+
+     return failures
+
+
+ def main() -> None:
+     parser = argparse.ArgumentParser(description="ACRE pre-submission validator")
+     parser.add_argument(
+         "--url",
+         default="http://localhost:7860",
+         help="Base URL of the running ACRE server",
+     )
+     args = parser.parse_args()
+     sys.exit(run_validation(args.url))
+
+
+ if __name__ == "__main__":
+     main()
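
The first check in this validator inspects `openenv_interface.py` with `ast` instead of importing it, so a missing dependency or import-time error in that module does not stop the validator from running. The standalone sketch below illustrates that technique under stated assumptions: `class_methods` and `SAMPLE` are hypothetical names introduced here for illustration and are not part of the repository, and unlike the validator the sketch also counts `async def` methods.

```python
import ast
from typing import Optional, Set


def class_methods(source: str, class_name: str) -> Optional[Set[str]]:
    """Return the method names defined directly on class_name, or None if it is absent."""
    tree = ast.parse(source)
    for node in tree.body:  # top-level statements only, as in the validator
        if isinstance(node, ast.ClassDef) and node.name == class_name:
            return {
                item.name
                for item in node.body
                if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef))
            }
    return None


# Hypothetical source text standing in for openenv_interface.py.
SAMPLE = '''
class OpenEnvRefactorEnv:
    def reset(self): ...
    def step(self, action): ...
'''

methods = class_methods(SAMPLE, "OpenEnvRefactorEnv")
missing = {"reset", "step", "state"} - (methods or set())
print(sorted(missing))
```

Because only the syntax tree is examined, this confirms that a method is declared on the class body, not that it is inherited or that it behaves correctly; the HTTP checks in sections 4 to 6 cover runtime behaviour.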