PRANAV05092003 committed on
Commit
8c9f7aa
·
1 Parent(s): 8d66fec

Added missing env module

Browse files
acre/env/__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Environment package for ACRE."""

# The package re-exports the environment class so callers can write
# ``from acre.env import RefactorEnv``.
from .refactor_env import RefactorEnv

__all__ = ["RefactorEnv"]
14
+
acre/env/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (283 Bytes). View file
 
acre/env/__pycache__/refactor_env.cpython-313.pyc ADDED
Binary file (14.9 kB). View file
 
acre/env/refactor_env.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ import re
5
+ import time
6
+ from dataclasses import dataclass
7
+ from typing import Any, Dict, Optional, Tuple
8
+
9
+ import multiprocessing as mp
10
+
11
+ import gymnasium as gym
12
+ import numpy as np
13
+
14
+ from acre.actions import transformations as tx
15
+ from acre.datasets.code_samples import CodeSample, CodeSampleDataset
16
+
17
+ try:
18
+ from radon.complexity import cc_visit
19
+ except Exception: # pragma: no cover
20
+ cc_visit = None # type: ignore[assignment]
21
+
22
+
23
@dataclass(frozen=True)
class _ExecResult:
    """Immutable outcome of one execution of candidate code."""

    # Process-style status code: 0 means the code ran to completion; the
    # executor in this module also produces 1 (exception / no result),
    # 2 (forbidden operation) and 124 (timeout).
    exit_code: int
    # Measurements collected for the run; the executor stores "runtime_s" here.
    metrics: Dict[str, Any]
    # Short description of the failure, or None when nothing went wrong.
    error: Optional[str] = None
28
+
29
+
30
# Regexes matched against the raw source text before it is executed.
# SECURITY NOTE(review): this denylist is a best-effort filter, not a sandbox.
# Forms such as ``from os import system`` or ``__import__("os")`` are not
# matched by any pattern below, so candidate code must still be treated as
# untrusted and isolated by other means.
_BANNED_PATTERNS: Tuple[str, ...] = (
    r"\bimport\s+os\b",
    r"\bimport\s+subprocess\b",
    r"\bimport\s+pathlib\b",
    r"\bimport\s+shutil\b",
    r"\bopen\s*\(",
    r"\bos\.(remove|unlink|rmdir|removedirs|rename|replace|system|popen)\b",
    r"\bshutil\.(rmtree|move|copy|copytree)\b",
    r"\bsubprocess\.(run|Popen|call|check_call|check_output)\b",
)


def _exec_worker(src: str, fname: str, out_q: "mp.Queue[dict]") -> None:
    """
    Run ``src`` and report the outcome on ``out_q``.

    Exactly one dict is put on the queue, with the keys ``exit_code``,
    ``runtime_s`` and ``error``:

    * ``exit_code == 0`` - the code compiled and ran to completion.
    * ``exit_code == 2`` - the source matched one of ``_BANNED_PATTERNS`` and
      was not executed (``error == "forbidden_operation"``).
    * ``exit_code == 1`` - compiling or running the code raised; ``error``
      holds the exception message (or the exception type name when the
      message is empty).

    Args:
        src: Source text of the candidate program.
        fname: File name passed to ``compile`` for tracebacks.
        out_q: Any object with a ``put`` method that receives the result dict.
    """
    start = time.perf_counter()
    exit_code = 0
    error: Optional[str] = None
    try:
        if any(re.search(pattern, src) for pattern in _BANNED_PATTERNS):
            exit_code, error = 2, "forbidden_operation"
        else:
            # SECURITY: executes untrusted code; see the note on _BANNED_PATTERNS.
            exec_globals: Dict[str, Any] = {"__name__": "__main__"}
            exec(compile(src, fname, "exec"), exec_globals, None)
    except SystemExit as exc:
        # SystemExit is not an Exception subclass; without this branch the
        # worker would exit without ever reporting a result.
        exit_code, error = 1, f"SystemExit({exc.code!r})"
    except Exception as exc:
        # Fall back to the type name so a bare ``raise ValueError`` does not
        # produce an empty, undiagnosable error string.
        exit_code, error = 1, str(exc) or type(exc).__name__

    runtime_s = time.perf_counter() - start
    out_q.put({"exit_code": exit_code, "runtime_s": float(runtime_s), "error": error})
60
+
61
+
62
class _InProcessExecutor:
    """
    Execute candidate code with a hard timeout to avoid hanging the server.

    This is critical for deployment: the agent can easily generate `while True: ...`
    or other long-running code. We treat timeout as an execution error.

    The code runs in a separate ``multiprocessing.Process`` whose target is
    ``_exec_worker``; the result travels back over a ``multiprocessing.Queue``.
    """

    # Longest single blocking wait on the result queue, so that a child which
    # died without reporting is noticed quickly.
    _POLL_INTERVAL_S = 0.02
    # Grace period for the child to exit after it reported, was terminated,
    # or was killed.
    _REAP_TIMEOUT_S = 0.1

    def run(self, code: str, *, filename: str = "<acre>", timeout_s: float = 0.25) -> _ExecResult:
        """
        Run ``code`` in a child process and return its outcome.

        Args:
            code: Source text to execute.
            filename: File name reported to ``compile`` inside the worker.
            timeout_s: Wall-clock budget in seconds (clamped to at least 0.01).
                NOTE(review): the clock starts right after ``Process.start()``,
                so child start-up counts against the budget; under the
                ``spawn`` start method that may be significant - confirm on the
                deployment platform.

        Returns:
            ``_ExecResult`` with ``exit_code`` 124 and ``error == "timeout"``
            when the budget ran out, ``exit_code`` 1 and ``error == "no result"``
            when the child exited without reporting, and otherwise whatever the
            worker reported.
        """
        from queue import Empty  # raised by mp.Queue.get() when nothing arrives

        budget_s = max(0.01, float(timeout_s))
        q: "mp.Queue[dict]" = mp.Queue(maxsize=1)
        # NOTE: on Windows, Process target must be picklable (top-level function).
        proc = mp.Process(target=_exec_worker, args=(code, filename, q), daemon=True)
        proc.start()

        payload: Optional[dict] = None
        timed_out = False
        deadline = time.monotonic() + budget_s
        try:
            # Drain the queue *before* joining: a child that has put data on a
            # queue does not exit until that data is flushed, so joining first
            # can misreport a finished run as a timeout.
            while True:
                remaining = deadline - time.monotonic()
                # Sample liveness before polling: if the child was already gone
                # and the queue is still empty, no result will ever arrive.
                child_alive = proc.is_alive()
                try:
                    payload = q.get(timeout=max(0.0, min(remaining, self._POLL_INTERVAL_S)))
                    break
                except Empty:
                    if not child_alive:
                        break
                    if remaining <= 0.0:
                        timed_out = True
                        break
        finally:
            self._reap(proc, wait_first=not timed_out)
            q.close()

        if timed_out:
            return _ExecResult(exit_code=124, metrics={"runtime_s": float(timeout_s)}, error="timeout")

        if payload is None:
            payload = {"exit_code": 1, "runtime_s": 0.0, "error": "no result"}

        return _ExecResult(
            exit_code=int(payload.get("exit_code", 1)),
            metrics={"runtime_s": float(payload.get("runtime_s", 0.0) or 0.0)},
            error=payload.get("error"),
        )

    def _reap(self, proc: "mp.Process", *, wait_first: bool) -> None:
        """Make sure ``proc`` is gone, escalating from join to terminate to kill."""
        if wait_first:
            proc.join(timeout=self._REAP_TIMEOUT_S)
        if proc.is_alive():
            proc.terminate()
            proc.join(timeout=self._REAP_TIMEOUT_S)
        if proc.is_alive():
            # terminate() can be ignored or delayed; kill() cannot be caught.
            proc.kill()
            proc.join(timeout=self._REAP_TIMEOUT_S)
93
+
94
+
95
class RefactorEnv(gym.Env):
    """
    Gymnasium environment in which each action applies one code transformation.

    An episode starts from a code sample drawn from ``dataset`` and lasts
    ``MAX_STEPS`` steps. After every transformation the code is re-measured
    (complexity and runtime) and the reward reflects the change.

    Observation (``float32`` vector of length 4):
        ``[len(code), complexity, last runtime in seconds, last error flag]``.

    Actions: the integer keys of ``ACTION_MEANINGS``.
    """

    metadata = {"render_modes": []}

    MAX_STEPS = 5

    ACTION_MEANINGS: Dict[int, str] = {
        0: "rename_variable",
        1: "remove_dead_code",
        2: "simplify_loop",
        3: "optimize_condition",
        4: "inline_function",
    }

    # Transformation function names in ``acre.actions.transformations`` for
    # actions 0..3; every other valid action maps to "inline_function".
    _TRANSFORM_NAMES: Tuple[str, ...] = (
        "rename_variable",
        "remove_dead_code",
        "simplify_loop",
        "optimize_condition",
    )

    def __init__(
        self,
        *,
        dataset: Optional[CodeSampleDataset] = None,
        seed: Optional[int] = None,
    ) -> None:
        """
        Args:
            dataset: Samples to draw episodes from. A falsy value is replaced
                by a one-sample default dataset.
            seed: Seed for the environment's random generator.
        """
        super().__init__()
        self.action_space = gym.spaces.Discrete(5)
        self.observation_space = gym.spaces.Box(
            low=np.array([0.0, 0.0, 0.0, 0.0], dtype=np.float32),
            high=np.array([1e9, 1e9, 1e9, 1.0], dtype=np.float32),
            dtype=np.float32,
        )

        self.dataset: CodeSampleDataset = dataset or CodeSampleDataset(
            [
                CodeSample(
                    id="default",
                    language="python",
                    code="def f(x):\n    return x\n",
                )
            ]
        )
        self._np_random, _ = gym.utils.seeding.np_random(seed)

        self.executor = _InProcessExecutor()

        # Per-episode state; populated for real by reset().
        self._episode_steps = 0
        self._sample: Optional[CodeSample] = None
        self._code: str = ""
        self._last_runtime_s: float = 0.0
        self._last_error: bool = False
        self._last_complexity: float = 0.0

    def _compute_complexity(self, code: str) -> float:
        """
        Return a complexity score for ``code``.

        Uses the sum of radon's per-block ``complexity`` values when radon is
        importable; falls back to the number of source lines when radon is
        missing or fails on this input.
        """
        line_count = float(len(code.splitlines()))
        if cc_visit is None:
            return line_count
        try:
            blocks = cc_visit(code)
        except Exception:
            # radon could not analyse the code (e.g. it does not parse).
            return line_count
        if not blocks:
            return 0.0
        try:
            return float(sum(getattr(b, "complexity", 0) for b in blocks))
        except Exception:
            return line_count

    def _compute_runtime(self, code: str) -> Tuple[float, bool, bool]:
        """Execute ``code`` and return ``(runtime_s, had_error, timed_out)``."""
        res = self.executor.run(code, filename="env_exec.py", timeout_s=0.25)
        runtime_s = float(res.metrics.get("runtime_s", 0.0) or 0.0)
        had_error = bool(res.exit_code != 0)
        timed_out = bool(res.exit_code == 124)  # 124 is the executor's timeout code
        return runtime_s, had_error, timed_out

    def _observation(self) -> np.ndarray:
        """Build the 4-element ``float32`` observation from the current state."""
        features = [
            float(len(self._code)),
            float(self._last_complexity),
            float(self._last_runtime_s),
            float(int(self._last_error)),
        ]
        return np.asarray(features, dtype=np.float32)

    def _apply_action(self, action_i: int, code: str):
        """Call the transformation selected by ``action_i`` on ``code``."""
        if 0 <= action_i < len(self._TRANSFORM_NAMES):
            name = self._TRANSFORM_NAMES[action_i]
        else:
            name = "inline_function"
        # Resolve lazily so only the chosen transformation is looked up.
        return getattr(tx, name)(code)

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        """
        Start a new episode on a randomly chosen sample.

        Args:
            seed: Optional seed; when given, the random generator is re-seeded.
            options: Accepted for API compatibility; not used.

        Returns:
            ``(observation, info)`` where ``info`` carries ``sample_id``,
            ``language`` and ``episode_steps``.
        """
        super().reset(seed=seed)
        if seed is not None:
            self._np_random, _ = gym.utils.seeding.np_random(seed)

        samples = list(self.dataset)
        if not samples:
            # Keep the episode well-defined even with nothing to draw from.
            samples = [CodeSample(id="empty", language="python", code="")]

        idx = int(self._np_random.integers(0, len(samples)))
        self._sample = samples[idx]
        self._code = str(self._sample.code)
        self._episode_steps = 0

        # Baseline measurements that the first step's reward is computed against.
        self._last_complexity = self._compute_complexity(self._code)
        self._last_runtime_s, self._last_error, _ = self._compute_runtime(self._code)

        info = {
            "sample_id": getattr(self._sample, "id", None),
            "language": getattr(self._sample, "language", None),
            "episode_steps": self._episode_steps,
        }
        return self._observation(), info

    def step(self, action: int):
        """
        Apply one transformation and score the result.

        Args:
            action: Key of ``ACTION_MEANINGS`` (converted with ``int``).

        Returns:
            ``(observation, raw_reward, terminated, truncated, info)``. The
            returned reward is the raw (un-normalized) value; the value clipped
            to [0, 1] is available as ``info["normalized_reward"]``.
            ``terminated`` becomes True once ``MAX_STEPS`` steps were taken;
            ``truncated`` is always False.

        Raises:
            ValueError: If ``action`` is not a key of ``ACTION_MEANINGS``.
        """
        action_i = int(action)
        if action_i not in self.ACTION_MEANINGS:
            raise ValueError(f"Invalid action {action_i}; expected 0..4")

        prev_complexity = float(self._last_complexity)
        prev_runtime = float(self._last_runtime_s)
        prev_error = bool(self._last_error)

        transform = self._apply_action(action_i, self._code)

        self._code = transform.code
        self._episode_steps += 1

        self._last_complexity = self._compute_complexity(self._code)
        self._last_runtime_s, self._last_error, is_timeout = self._compute_runtime(self._code)

        # Relative improvements; the denominators are floored to avoid dividing by zero.
        complexity_gain = (prev_complexity - float(self._last_complexity)) / max(prev_complexity, 1.0)
        runtime_gain = (prev_runtime - float(self._last_runtime_s)) / max(prev_runtime, 1e-6)
        # Penalize execution errors strongly; timeouts even more strongly.
        timeout_penalty = -2.0 if is_timeout else 0.0
        error_penalty = -1.0 if self._last_error else 0.0
        change_bonus = 0.05 if transform.changed else 0.0
        no_change_penalty = -0.02 if not transform.changed else 0.0

        raw_reward = float(
            2.0 * complexity_gain
            + 0.25 * runtime_gain
            + error_penalty
            + timeout_penalty
            + change_bonus
            + no_change_penalty
        )
        # Extra shaping on error-state transitions: breaking working code costs
        # 0.5, repairing broken code earns 0.5.
        if (not prev_error) and self._last_error:
            raw_reward -= 0.5
        if prev_error and (not self._last_error):
            raw_reward += 0.5

        # Normalize exactly as declared in openenv.yaml (clip to [0,1]).
        normalized_reward = float((raw_reward + 32.0) / 52.0)
        if normalized_reward < 0.0:
            normalized_reward = 0.0
        elif normalized_reward > 1.0:
            normalized_reward = 1.0

        terminated = bool(self._episode_steps >= int(self.MAX_STEPS))
        truncated = False

        info: Dict[str, Any] = {
            "action_name": self.ACTION_MEANINGS[action_i],
            "changed": bool(transform.changed),
            "transform": dict(transform.metadata),
            "reward_components": {
                "complexity_gain": float(complexity_gain),
                "runtime_gain": float(runtime_gain),
                "error_penalty": float(error_penalty),
                "timeout_penalty": float(timeout_penalty),
                "change_bonus": float(change_bonus),
                "no_change_penalty": float(no_change_penalty),
            },
            "normalized_reward": normalized_reward,
            "episode_steps": int(self._episode_steps),
            "timeout": bool(is_timeout),
        }
        return self._observation(), raw_reward, terminated, truncated, info

    def state(self) -> Dict[str, Any]:
        """Return a plain-dict snapshot of the current episode state."""
        # getattr with a default already yields None when no sample is loaded.
        return {
            "current_code": self._code,
            "episode_steps": int(self._episode_steps),
            "max_steps": int(self.MAX_STEPS),
            "complexity": float(self._last_complexity),
            "last_runtime": float(self._last_runtime_s),
            "last_error": bool(self._last_error),
            "sample_id": getattr(self._sample, "id", None),
            "language": getattr(self._sample, "language", None),
            "observation": self._observation().tolist(),
            "action_meanings": dict(self.ACTION_MEANINGS),
        }

    def render(self) -> None:
        """No-op: ``metadata`` declares no render modes."""
        return None

    def close(self) -> None:
        """No-op: this method releases nothing."""
        return None
289
+