Commit 1b7b2a4 by s-shah4 · 1 parent: c790f8d

Add ADAPT V0 environment

Files changed (2)
  1. README.md +81 -1
  2. environment.py +80 -0
README.md CHANGED
@@ -1 +1,81 @@
- # meta-rl-dsa-solver
+ # meta-rl-dsa-solver
+
+ ADAPT (Adversarial DSA Tutor) is a minimal reinforcement learning environment for coding tasks. The current V0 environment is a pure Python class with no API dependency, so it can be used directly from a training loop with `env.reset()` and `env.step(...)`.
+
+ ## Current V0
+
+ - Fixed DSA problem: given an integer `n`, return `n * 2`
+ - Single test input: `5`
+ - Expected output: `10`
+ - Binary reward: `1.0` for correct output, `0.0` otherwise
+ - Subprocess execution with a 2 second timeout
+
+ ## Run a Smoke Test
+
+ From this directory:
+
+ ```powershell
+ cd C:\Users\kaust\PycharmProjects\meta-rl-dsa-solver
+ python3 -c "from environment import AdaptEnv; env=AdaptEnv(); print(env.reset()); print(env.step('n=int(input()); print(n*2)'))"
+ ```
+
+ Expected reward:
+
+ ```text
+ 1.0
+ ```
+
+ ## Use in Python
+
+ ```python
+ from environment import AdaptEnv
+
+ env = AdaptEnv()
+
+ obs = env.reset()
+ print(obs)
+
+ code = "n=int(input()); print(n*2)"
+ result = env.step(code)
+
+ print(result)
+ assert result["reward"] == 1.0
+ ```
+
+ ## Check Failure Cases
+
+ Wrong answer:
+
+ ```powershell
+ python3 -c "from environment import AdaptEnv; env=AdaptEnv(); env.reset(); print(env.step('print(0)'))"
+ ```
+
+ Timeout:
+
+ ```powershell
+ python3 -c "from environment import AdaptEnv; env=AdaptEnv(); env.reset(); print(env.step('while True: pass'))"
+ ```
+
+ ## Environment Contract
+
+ `reset()` returns:
+
+ ```python
+ {
+     "problem": "Given an integer n, return n * 2",
+     "input": "5",
+ }
+ ```
+
+ `step(action: str)` returns:
+
+ ```python
+ {
+     "observation": "<program output or error>",
+     "reward": 1.0,
+     "done": True,
+     "info": {},
+ }
+ ```
+
+ The implementation keeps the verifier pluggable so later versions can replace the single expected-output check with hidden tests, randomized inputs, or adaptive curriculum logic.
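
Note on the pluggable verifier: `environment.py` in this commit accepts any callable of the form `verifier(output, expected_output) -> (reward, info)`. The sketch below shows what such a callable might look like. The name `whitespace_tolerant_verifier` and its token-wise comparison rule are illustrative assumptions, not part of the commit.

```python
from __future__ import annotations

from typing import Any


def whitespace_tolerant_verifier(output: str, expected: str) -> tuple[float, dict[str, Any]]:
    """Hypothetical verifier: compare outputs token by token, ignoring whitespace layout."""
    got_tokens = output.split()
    want_tokens = expected.split()
    matched = got_tokens == want_tokens
    # The info dict is free-form; these keys are made up for illustration.
    info = {"matched": matched, "got": got_tokens, "expected": want_tokens}
    return (1.0 if matched else 0.0), info


# Intended usage with the class from this commit (not executed here):
#   env = AdaptEnv(verifier=whitespace_tolerant_verifier)
reward, info = whitespace_tolerant_verifier(" 10\n", "10")
print(reward, info["matched"])
```

As committed, `_compute_reward` keeps only the reward from the verifier and `step()` always returns an empty `info` dict, so the second element of the tuple is currently discarded.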
environment.py ADDED
@@ -0,0 +1,80 @@
+ from __future__ import annotations
+
+ import subprocess
+ import tempfile
+ from pathlib import Path
+ from typing import Any, Callable
+
+
+ class AdaptEnv:
+     def __init__(
+         self,
+         verifier: Callable[[str, str], tuple[float, dict[str, Any]]] | None = None,
+     ):
+         self.verifier = verifier
+         self.problem = ""
+         self.current_input = ""
+         self.expected_output = ""
+         self.step_count = 0
+         self.last_output = ""
+
+     def reset(self) -> dict[str, str]:
+         self.problem = "Given an integer n, return n * 2"
+         self.current_input = "5"
+         self.expected_output = "10"
+         self.step_count = 0
+         self.last_output = ""
+
+         return {
+             "problem": self.problem,
+             "input": self.current_input,
+         }
+
+     def step(self, action: str) -> dict[str, Any]:
+         if not self.problem:
+             self.reset()
+
+         self.step_count += 1
+         output = self._run_code(action)
+         reward = self._compute_reward(output)
+         self.last_output = output
+
+         return {
+             "observation": output,
+             "reward": reward,
+             "done": True,
+             "info": {},
+         }
+
+     def _run_code(self, code: str) -> str:
+         with tempfile.TemporaryDirectory() as tmpdir:
+             file_path = Path(tmpdir) / "submission.py"
+             file_path.write_text(code, encoding="utf-8")
+
+             try:
+                 result = subprocess.run(
+                     ["python3", str(file_path)],
+                     input=self.current_input,
+                     text=True,
+                     capture_output=True,
+                     timeout=2,
+                 )
+             except subprocess.TimeoutExpired:
+                 return "ERROR: timeout"
+             except Exception as exc:
+                 return f"ERROR: {exc}"
+
+             if result.returncode != 0:
+                 stderr = result.stderr.strip()
+                 return f"ERROR: {stderr or 'runtime error'}"
+
+             return result.stdout.strip()
+
+     def _compute_reward(self, output: str) -> float:
+         if self.verifier is not None:
+             reward, _info = self.verifier(output, self.expected_output)
+             return reward
+
+         if output == self.expected_output:
+             return 1.0
+         return 0.0
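
Note on interpreter lookup: `_run_code` launches submissions with the executable name `python3`, which may not resolve on every platform (for example, some Windows installs only expose `python`). One possible variant, sketched below, reuses the interpreter that is already running via `sys.executable`. The function `run_submission` and its `timeout_s` parameter are hypothetical standalone stand-ins for the method, not code from this commit.

```python
from __future__ import annotations

import subprocess
import sys
import tempfile
from pathlib import Path


def run_submission(code: str, stdin_text: str, timeout_s: float = 2.0) -> str:
    """Hypothetical helper: run `code` in a subprocess and return stdout or an ERROR string."""
    with tempfile.TemporaryDirectory() as tmpdir:
        file_path = Path(tmpdir) / "submission.py"
        file_path.write_text(code, encoding="utf-8")
        try:
            result = subprocess.run(
                # sys.executable is the absolute path of the current interpreter.
                [sys.executable, str(file_path)],
                input=stdin_text,
                text=True,
                capture_output=True,
                timeout=timeout_s,
            )
        except subprocess.TimeoutExpired:
            return "ERROR: timeout"
    if result.returncode != 0:
        return f"ERROR: {result.stderr.strip() or 'runtime error'}"
    return result.stdout.strip()


print(run_submission("n=int(input()); print(n*2)", "5"))  # prints 10
```

This keeps the same error-string convention as the committed method (`ERROR: timeout`, `ERROR: <stderr>`), so the default exact-match reward check would behave the same way.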