kaustubhg73 committed on
Commit
96b50a5
·
1 Parent(s): 864223c

Add HF support

Browse files
README.md CHANGED
@@ -14,7 +14,7 @@ tags:
14
 
15
  ADAPT, the Adversarial DSA Tutor, is an OpenEnv-compliant RLVR environment for training code-generation agents on small DSA tasks. The agent receives a problem prompt, examples, and visible tests, then submits Python code. The environment runs the code against visible and hidden tests and returns reward, pass-rate metrics, execution status, and feedback.
16
 
17
- This repo now focuses on the environment layer only. Verifier work and training scripts are owned separately.
18
 
19
  ## Why This Environment
20
 
@@ -120,11 +120,15 @@ uvicorn server.app:app --host 0.0.0.0 --port 7860
120
 
121
  Useful endpoints:
122
 
 
123
  - `GET /health`
 
 
124
  - `GET /schema`
125
  - `POST /reset`
126
  - `POST /step`
127
  - `GET /state`
 
128
 
129
  Example step request:
130
 
@@ -132,12 +136,45 @@ Example step request:
132
  curl -X POST http://localhost:7860/step -H "Content-Type: application/json" -d "{\"action\":{\"code\":\"n=int(input())\nprint(n*2)\"}}"
133
  ```
134
 
 
 
 
 
 
 
135
  Validate with OpenEnv once dependencies are installed:
136
 
137
  ```powershell
138
  openenv validate .
139
  ```
140
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  ## Hugging Face Spaces
142
 
143
  This repo is Docker Space ready:
 
14
 
15
  ADAPT, the Adversarial DSA Tutor, is an OpenEnv-compliant RLVR environment for training code-generation agents on small DSA tasks. The agent receives a problem prompt, examples, and visible tests, then submits Python code. The environment runs the code against visible and hidden tests and returns reward, pass-rate metrics, execution status, and feedback.
16
 
17
+ This repo includes the environment, verifier helpers, a baseline inference runner, and a GRPO training entrypoint so the full submission flow can be exercised from one codebase.
18
 
19
  ## Why This Environment
20
 
 
120
 
121
  Useful endpoints:
122
 
123
+ - `GET /`
124
  - `GET /health`
125
+ - `GET /metadata`
126
+ - `GET /tasks`
127
  - `GET /schema`
128
  - `POST /reset`
129
  - `POST /step`
130
  - `GET /state`
131
+ - `POST /mcp`
132
 
133
  Example step request:
134
 
 
136
  curl -X POST http://localhost:7860/step -H "Content-Type: application/json" -d "{\"action\":{\"code\":\"n=int(input())\nprint(n*2)\"}}"
137
  ```
138
 
139
+ You can also send the raw action body:
140
+
141
+ ```powershell
142
+ curl -X POST http://localhost:7860/step -H "Content-Type: application/json" -d "{\"code\":\"n=int(input())\nprint(n*2)\"}"
143
+ ```
144
+
145
  Validate with OpenEnv once dependencies are installed:
146
 
147
  ```powershell
148
  openenv validate .
149
  ```
150
 
151
+ Run the verifier smoke test:
152
+
153
+ ```powershell
154
+ python scripts\test_verifier.py
155
+ ```
156
+
157
+ Run the environment smoke test:
158
+
159
+ ```powershell
160
+ python scripts\test_env.py
161
+ ```
162
+
163
+ Run the baseline model loop:
164
+
165
+ ```powershell
166
+ $env:HF_TOKEN="..."
167
+ $env:API_BASE_URL="https://router.huggingface.co/v1"
168
+ $env:MODEL_NAME="openai/gpt-oss-120b"
169
+ python inference.py
170
+ ```
171
+
172
+ Run GRPO training:
173
+
174
+ ```powershell
175
+ python training\train_grpo.py --output-dir outputs_v2 --bf16
176
+ ```
177
+
178
  ## Hugging Face Spaces
179
 
180
  This repo is Docker Space ready:
client.py CHANGED
@@ -21,7 +21,7 @@ class AdaptEnvClient:
21
  return response.json()
22
 
23
  def step(self, code: str) -> dict[str, Any]:
24
- response = self._client.post("/step", json={"action": AdaptAction(code=code).model_dump()})
25
  response.raise_for_status()
26
  return response.json()
27
 
 
21
  return response.json()
22
 
23
  def step(self, code: str) -> dict[str, Any]:
24
+ response = self._client.post("/step", json=AdaptAction(code=code).model_dump())
25
  response.raise_for_status()
26
  return response.json()
27
 
env/adapt_env.py CHANGED
@@ -1,15 +1,26 @@
1
  from __future__ import annotations
2
 
3
  import ast
4
- from typing import Any
5
  from uuid import uuid4
6
 
7
- from openenv.core.env_server.interfaces import Environment
8
-
9
  from env.executor import run_code
10
  from env.test_cases import get_test_cases, load_problem, split_test_cases
11
  from models import AdaptAction, AdaptObservation, AdaptState
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  FORBIDDEN_IMPORTS = {"os", "pathlib", "shutil", "socket", "subprocess"}
15
  DIFFICULTY_LABELS = {1: "easy", 2: "medium", 3: "hard"}
 
1
  from __future__ import annotations
2
 
3
  import ast
4
+ from typing import Any, Generic, TypeVar
5
  from uuid import uuid4
6
 
 
 
7
  from env.executor import run_code
8
  from env.test_cases import get_test_cases, load_problem, split_test_cases
9
  from models import AdaptAction, AdaptObservation, AdaptState
10
 
11
+ try:
12
+ from openenv.core.env_server.interfaces import Environment
13
+ except ImportError:
14
+ ActionT = TypeVar("ActionT")
15
+ ObservationT = TypeVar("ObservationT")
16
+ StateT = TypeVar("StateT")
17
+
18
+ class Environment(Generic[ActionT, ObservationT, StateT]):
19
+ SUPPORTS_CONCURRENT_SESSIONS = False
20
+
21
+ def __init__(self) -> None:
22
+ pass
23
+
24
 
25
  FORBIDDEN_IMPORTS = {"os", "pathlib", "shutil", "socket", "subprocess"}
26
  DIFFICULTY_LABELS = {1: "easy", 2: "medium", 3: "hard"}
env/executor.py CHANGED
@@ -23,7 +23,7 @@ def run_code(code: str, input_data: str) -> dict:
23
 
24
  try:
25
  result = subprocess.run(
26
- ["python3", str(file_path)],
27
  input=input_data,
28
  text=True,
29
  capture_output=True,
 
23
 
24
  try:
25
  result = subprocess.run(
26
+ [sys.executable, str(file_path)],
27
  input=input_data,
28
  text=True,
29
  capture_output=True,
inference.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ STDOUT FORMAT (must match exactly):
3
+ [START] task=<task_name> env=adapt_dsa_tutor model=<model_name>
4
+ [STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
5
+ [END] success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...,rn>
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import json
11
+ import os
12
+ from typing import Any
13
+
14
+ from env.adapt_env import AdaptEnvironment
15
+ from env.test_cases import load_problem_bank
16
+ from models import AdaptAction
17
+
18
+ BENCHMARK = "adapt_dsa_tutor"
19
+ TASKS = [problem["problem_id"] for problem in load_problem_bank()]
20
+ SYSTEM_PROMPT = """You are solving a programming problem in Python.
21
+
22
+ You will receive:
23
+ - a problem statement
24
+ - input format
25
+ - constraints
26
+ - worked examples
27
+ - visible tests
28
+ - feedback from previous attempts
29
+
30
+ Reply with ONLY runnable Python code. The code must read from stdin and print to stdout.
31
+ Do not include markdown fences or explanations."""
32
+
33
+
34
+ def require_env(name: str, value: str | None) -> str:
35
+ if value:
36
+ return value
37
+ raise RuntimeError(f"Missing required environment variable: {name}")
38
+
39
+
40
+ def safe_log_value(value: str | None) -> str:
41
+ if not value:
42
+ return "null"
43
+ return str(value).replace("\n", "_").replace("\r", "_").replace("\t", "_").replace(" ", "_")
44
+
45
+
46
+ def extract_code(response_text: str) -> str:
47
+ text = response_text.strip()
48
+ if text.startswith("```"):
49
+ parts = text.split("\n", 1)
50
+ text = parts[1] if len(parts) > 1 else text[3:]
51
+ if text.endswith("```"):
52
+ text = text[:-3]
53
+ text = text.strip()
54
+ if text.startswith("python"):
55
+ text = text[6:].strip()
56
+ return text
57
+
58
+
59
+ def build_user_prompt(observation: dict[str, Any]) -> str:
60
+ payload = {
61
+ "problem_id": observation["problem_id"],
62
+ "difficulty": observation["difficulty"],
63
+ "problem": observation["problem"],
64
+ "input_format": observation["input_format"],
65
+ "constraints": observation["constraints"],
66
+ "examples": observation["examples"],
67
+ "visible_tests": observation["visible_tests"],
68
+ "feedback": observation["feedback"],
69
+ }
70
+ return json.dumps(payload, indent=2)
71
+
72
+
73
+ def log_start(task: str, env: str, model: str) -> None:
74
+ print(f"[START] task={task} env={env} model={model}", flush=True)
75
+
76
+
77
+ def log_step(step: int, action_str: str, reward: float, done: bool, error: str | None) -> None:
78
+ print(
79
+ f"[STEP] step={step} action={safe_log_value(action_str)} reward={reward:.2f} "
80
+ f"done={str(done).lower()} error={safe_log_value(error)}",
81
+ flush=True,
82
+ )
83
+
84
+
85
+ def log_end(success: bool, steps: int, score: float, rewards: list[float]) -> None:
86
+ rewards_str = ",".join(f"{reward:.2f}" for reward in rewards)
87
+ print(
88
+ f"[END] success={str(success).lower()} steps={steps} score={score:.2f} rewards={rewards_str}",
89
+ flush=True,
90
+ )
91
+
92
+
93
+ def run_task(task_name: str) -> float:
94
+ try:
95
+ from openai import OpenAI
96
+ except ImportError as exc:
97
+ raise RuntimeError(
98
+ "The `openai` package is required for inference runs. Install it before running inference.py."
99
+ ) from exc
100
+
101
+ api_key = require_env("HF_TOKEN", os.getenv("HF_TOKEN"))
102
+ base_url = require_env("API_BASE_URL", os.getenv("API_BASE_URL", "https://router.huggingface.co/v1"))
103
+ model_name = require_env("MODEL_NAME", os.getenv("MODEL_NAME", "openai/gpt-oss-120b"))
104
+
105
+ client = OpenAI(base_url=base_url, api_key=api_key)
106
+ env = AdaptEnvironment()
107
+ observation = env.reset(problem_id=task_name)
108
+
109
+ log_start(task_name, BENCHMARK, model_name)
110
+ rewards = []
111
+ messages = [{"role": "system", "content": SYSTEM_PROMPT}]
112
+
113
+ max_steps = 3
114
+ for step_index in range(1, max_steps + 1):
115
+ messages.append({"role": "user", "content": build_user_prompt(observation.model_dump())})
116
+
117
+ try:
118
+ response = client.chat.completions.create(
119
+ model=model_name,
120
+ messages=messages,
121
+ temperature=0.0,
122
+ max_tokens=512,
123
+ )
124
+ response_text = response.choices[0].message.content or ""
125
+ messages.append({"role": "assistant", "content": response_text})
126
+
127
+ code = extract_code(response_text)
128
+ observation = env.step(AdaptAction(code=code))
129
+ rewards.append(float(observation.reward))
130
+ log_step(step_index, "submit_code", float(observation.reward), bool(observation.done), None)
131
+
132
+ if observation.pass_rate == 1.0 or observation.done:
133
+ break
134
+
135
+ except Exception as exc:
136
+ rewards.append(0.0)
137
+ log_step(step_index, "parse_error", 0.0, False, str(exc))
138
+ messages.append(
139
+ {
140
+ "role": "user",
141
+ "content": f"Your last response failed. Error: {exc}. Reply with only Python code.",
142
+ }
143
+ )
144
+
145
+ success = observation.pass_rate == 1.0
146
+ score = float(observation.reward)
147
+ log_end(success, len(rewards), score, rewards)
148
+ return score
149
+
150
+
151
+ def main() -> dict[str, float]:
152
+ scores = {}
153
+ for task in TASKS:
154
+ scores[task] = run_task(task)
155
+ return scores
156
+
157
+
158
+ if __name__ == "__main__":
159
+ main()
models.py CHANGED
@@ -2,7 +2,22 @@ from __future__ import annotations
2
 
3
  from typing import Any
4
 
5
- from openenv.core.env_server.types import Action, Observation, State
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  from pydantic import Field
7
 
8
 
 
2
 
3
  from typing import Any
4
 
5
+ from pydantic import BaseModel, Field
6
+
7
+ try:
8
+ from openenv.core.env_server.types import Action, Observation, State
9
+ except ImportError:
10
+ class Action(BaseModel):
11
+ model_config = {"extra": "forbid"}
12
+
13
+ class Observation(BaseModel):
14
+ reward: float = Field(default=0.0, ge=0.0, le=1.0)
15
+ done: bool = False
16
+
17
+ class State(BaseModel):
18
+ episode_id: str = ""
19
+ step_count: int = 0
20
+
21
  from pydantic import Field
22
 
23
 
pyproject.toml CHANGED
@@ -26,5 +26,5 @@ server = "server.app:main"
26
 
27
  [tool.setuptools]
28
  include-package-data = true
29
- packages = ["env", "server"]
30
  py-modules = ["app", "client", "models"]
 
26
 
27
  [tool.setuptools]
28
  include-package-data = true
29
+ packages = ["env", "server", "verifier"]
30
  py-modules = ["app", "client", "models"]
scripts/test_env.py CHANGED
@@ -1,3 +1,12 @@
 
 
 
 
 
 
 
 
 
1
  from env.adapt_env import AdaptEnvironment
2
  from models import AdaptAction
3
 
 
1
+ from __future__ import annotations
2
+
3
+ import sys
4
+ from pathlib import Path
5
+
6
+ ROOT = Path(__file__).resolve().parents[1]
7
+ if str(ROOT) not in sys.path:
8
+ sys.path.insert(0, str(ROOT))
9
+
10
  from env.adapt_env import AdaptEnvironment
11
  from models import AdaptAction
12
 
scripts/test_verifier.py CHANGED
@@ -1,3 +1,12 @@
 
 
 
 
 
 
 
 
 
1
  from verifier.verifier import verify
2
 
3
 
@@ -38,4 +47,4 @@ for name, code in [
38
  print("Pass rate:", info["pass_rate"])
39
  print("Passed:", info["passed"], "/", info["total"])
40
  print("Timeouts:", info["timeout_count"])
41
- print("Errors:", info["error_count"])
 
1
+ from __future__ import annotations
2
+
3
+ import sys
4
+ from pathlib import Path
5
+
6
+ ROOT = Path(__file__).resolve().parents[1]
7
+ if str(ROOT) not in sys.path:
8
+ sys.path.insert(0, str(ROOT))
9
+
10
  from verifier.verifier import verify
11
 
12
 
 
47
  print("Pass rate:", info["pass_rate"])
48
  print("Passed:", info["passed"], "/", info["total"])
49
  print("Timeouts:", info["timeout_count"])
50
+ print("Errors:", info["error_count"])
server/app.py CHANGED
@@ -1,28 +1,162 @@
1
  from __future__ import annotations
2
 
3
- try:
4
- from openenv.core.env_server.http_server import create_app
5
- except Exception as exc: # pragma: no cover
6
- raise ImportError(
7
- "openenv-core>=0.2.3 is required. Install with: pip install -e ."
8
- ) from exc
9
 
10
- from env.adapt_env import AdaptEnvironment
11
- from models import AdaptAction, AdaptObservation
 
 
12
 
 
 
 
13
 
14
- app = create_app(
15
- AdaptEnvironment,
16
- AdaptAction,
17
- AdaptObservation,
18
- env_name="adapt_dsa_tutor",
19
- max_concurrent_envs=4,
20
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
 
 
 
 
 
22
 
23
- def main(host: str = "0.0.0.0", port: int = 7860) -> None:
24
- import uvicorn
25
 
 
 
 
 
 
 
 
 
26
  uvicorn.run(app, host=host, port=port)
27
 
28
 
 
1
  from __future__ import annotations
2
 
3
+ import argparse
4
+ from typing import Any
 
 
 
 
5
 
6
+ import uvicorn
7
+ from fastapi import Body, FastAPI, HTTPException, Request
8
+ from fastapi.responses import RedirectResponse, Response
9
+ from pydantic import BaseModel
10
 
11
+ from env.adapt_env import AdaptEnvironment
12
+ from env.test_cases import load_problem_bank
13
+ from models import AdaptAction, AdaptObservation, AdaptState
14
 
15
+ ENV_NAME = "adapt_dsa_tutor"
16
+ ENV_DESCRIPTION = (
17
+ "RL environment for DSA code generation with hidden tests, tiered problems, "
18
+ "and verifier-aware reward shaping."
 
 
19
  )
20
+ TASKS = [
21
+ {
22
+ "name": problem["problem_id"],
23
+ "difficulty": problem["difficulty"],
24
+ "description": problem["problem"],
25
+ }
26
+ for problem in load_problem_bank()
27
+ ]
28
+
29
+ app = FastAPI(title="ADAPT DSA Tutor OpenEnv", version="0.2.0")
30
+ ENV = AdaptEnvironment()
31
+
32
+
33
+ class ResetRequest(BaseModel):
34
+ seed: int | None = None
35
+ episode_id: str | None = None
36
+ problem_id: str | None = None
37
+ difficulty: str | None = None
38
+
39
+
40
+ def _metadata() -> dict[str, Any]:
41
+ return {
42
+ "name": ENV_NAME,
43
+ "description": ENV_DESCRIPTION,
44
+ "version": "0.2.0",
45
+ "tasks": TASKS,
46
+ "mode": "simulation",
47
+ }
48
+
49
+
50
+ @app.get("/")
51
+ def root() -> dict[str, Any]:
52
+ payload = _metadata()
53
+ payload["status"] = "ok"
54
+ return payload
55
+
56
+
57
+ @app.get("/web", include_in_schema=False)
58
+ def web_root() -> RedirectResponse:
59
+ return RedirectResponse(url="/", status_code=307)
60
+
61
+
62
+ @app.get("/web/", include_in_schema=False)
63
+ def web_root_slash() -> RedirectResponse:
64
+ return RedirectResponse(url="/", status_code=307)
65
+
66
+
67
+ @app.get("/favicon.ico", include_in_schema=False)
68
+ def favicon() -> Response:
69
+ return Response(status_code=204)
70
+
71
+
72
+ @app.get("/health")
73
+ def health() -> dict[str, str]:
74
+ return {"status": "healthy"}
75
+
76
+
77
+ @app.get("/metadata")
78
+ def metadata() -> dict[str, Any]:
79
+ return _metadata()
80
+
81
+
82
+ @app.get("/tasks")
83
+ def list_tasks() -> dict[str, Any]:
84
+ return {"tasks": TASKS}
85
+
86
+
87
+ @app.get("/schema")
88
+ def schema() -> dict[str, Any]:
89
+ return {
90
+ "action": AdaptAction.model_json_schema(),
91
+ "observation": AdaptObservation.model_json_schema(),
92
+ "state": AdaptState.model_json_schema(),
93
+ }
94
+
95
+
96
+ @app.post("/mcp")
97
+ def mcp(payload: dict[str, Any] = Body(default_factory=dict)) -> dict[str, Any]:
98
+ return {
99
+ "jsonrpc": "2.0",
100
+ "id": payload.get("id"),
101
+ "error": {
102
+ "code": -32601,
103
+ "message": "MCP methods are not implemented for this environment.",
104
+ },
105
+ }
106
+
107
+
108
+ @app.post("/reset")
109
+ def reset(request: ResetRequest | None = None) -> dict[str, Any]:
110
+ effective_request = request or ResetRequest()
111
+ observation = ENV.reset(
112
+ seed=effective_request.seed,
113
+ episode_id=effective_request.episode_id,
114
+ problem_id=effective_request.problem_id,
115
+ difficulty=effective_request.difficulty,
116
+ )
117
+ return observation.model_dump()
118
+
119
+
120
@app.post("/step")
async def step(request: Request) -> dict[str, Any]:
    """Submit one action to the shared environment and return the result.

    Accepts either ``{"action": {...}}`` or the bare action object as the
    JSON body.

    Returns:
        A dict with the dumped observation, its reward and done flag, and an
        ``info`` dict carrying feedback, pass rate and execution status.

    Raises:
        HTTPException: 422 when the body is not valid JSON, is not a JSON
            object, or does not validate as an ``AdaptAction``.
    """
    try:
        payload = await request.json()
    except ValueError as exc:
        # Covers json.JSONDecodeError and undecodable bytes; without this an
        # empty or malformed body surfaces as an unhandled 500.
        raise HTTPException(status_code=422, detail="Request body must be valid JSON.") from exc
    if not isinstance(payload, dict):
        raise HTTPException(status_code=422, detail="Request body must be a JSON object.")

    # Fall back to the whole payload when there is no "action" wrapper.
    raw_action = payload.get("action", payload)
    try:
        effective_action = AdaptAction.model_validate(raw_action)
    except Exception as exc:
        raise HTTPException(status_code=422, detail=f"Invalid action payload: {exc}") from exc

    observation = ENV.step(effective_action)
    return {
        "observation": observation.model_dump(),
        "reward": float(observation.reward),
        "done": bool(observation.done),
        "info": {
            "feedback": observation.feedback,
            "pass_rate": observation.pass_rate,
            "execution_status": observation.execution_status,
        },
    }
143
+
144
 
145
+ @app.get("/state")
146
+ def state() -> dict[str, Any]:
147
+ if not ENV.problem:
148
+ ENV.reset()
149
+ return ENV.state.model_dump()
150
 
 
 
151
 
152
+ def main(host: str | None = None, port: int | None = None) -> None:
153
+ if host is None or port is None:
154
+ parser = argparse.ArgumentParser()
155
+ parser.add_argument("--host", default="0.0.0.0")
156
+ parser.add_argument("--port", type=int, default=7860)
157
+ args = parser.parse_args()
158
+ host = args.host if host is None else host
159
+ port = args.port if port is None else port
160
  uvicorn.run(app, host=host, port=port)
161
 
162
 
test.py CHANGED
@@ -12,7 +12,7 @@ def assert_hidden_tests_are_not_exposed(payload: dict) -> None:
12
 
13
  def main() -> None:
14
  env = AdaptEnvironment()
15
- observation = env.reset()
16
  assert isinstance(observation, AdaptObservation)
17
  assert observation.visible_tests
18
  assert observation.problem_id == "easy_double"
 
12
 
13
  def main() -> None:
14
  env = AdaptEnvironment()
15
+ observation = env.reset(problem_id="easy_double")
16
  assert isinstance(observation, AdaptObservation)
17
  assert observation.visible_tests
18
  assert observation.problem_id == "easy_double"
training/train_grpo.py CHANGED
@@ -1,167 +1,152 @@
1
- import torch
2
- <<<<<<< HEAD
3
- from unsloth import FastLanguageModel, PatchFastRL
4
- from trl import GRPOTrainer, GRPOConfig
5
- from meta_rl_dsa_solver_env import DsaEnv
6
-
7
- # 1. Patch Unsloth for RL speedups
8
- PatchFastRL("GRPO", FastLanguageModel)
9
-
10
- # 2. Load Model & Tokenizer
11
- model, tokenizer = FastLanguageModel.from_pretrained(
12
- model_name = "unsloth/Llama-3.2-3B-Instruct", # Use appropriate 2026 base
13
- max_seq_length = 2048,
14
- load_in_4bit = True,
15
- fast_inference = True,
16
- =======
17
- import numpy as np
18
- from unsloth import FastLanguageModel, PatchFastRL
19
- from trl import GRPOTrainer, GRPOConfig
20
- from meta_rl_dsa_solver_env.env.adapt_env import AdaptEnvironment
21
- from meta_rl_dsa_solver_env.models import AdaptAction
22
-
23
- # 1. Initialize Model & Speedups
24
- PatchFastRL("GRPO", FastLanguageModel)
25
-
26
- model, tokenizer = FastLanguageModel.from_pretrained(
27
- model_name = "unsloth/Llama-3.2-3B-Instruct",
28
- max_seq_length = 2048,
29
- load_in_4bit = True,
30
- >>>>>>> environment-v2
31
- )
32
-
33
- model = FastLanguageModel.get_peft_model(
34
- model,
35
- r = 16,
36
- target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"],
37
- lora_alpha = 16,
38
- <<<<<<< HEAD
39
- lora_dropout = 0,
40
- )
41
-
42
- # 3. Define the Reward Function (Interface for Person 2)
43
- def reward_function(prompts, completions, **kwargs) -> list[float]:
44
- """
45
- In GRPO, the reward function is called on the batch of completions.
46
- For V0, we manually trigger the Env's step logic.
47
- """
48
- env = DsaEnv()
49
- rewards = []
50
-
51
- for completion in completions:
52
- # Extract code from completion (assuming markdown tags)
53
- code = completion.split("```python")[-1].split("```")[0].strip() if "```" in completion else completion
54
- _, reward, _, _, _ = env.step(code)
55
- rewards.append(reward)
56
-
57
- return rewards
58
-
59
- # 4. Training Configuration
60
- training_args = GRPOConfig(
61
- output_dir = "./outputs",
62
- learning_rate = 5e-6,
63
- per_device_train_batch_size = 4,
64
- gradient_accumulation_steps = 4,
65
- max_prompt_length = 512,
66
- max_completion_length = 512,
67
- num_generations = 8, # Group size for GRPO
68
- logging_steps = 1,
69
- max_steps = 100, # Quick run for MVP
70
- )
71
-
72
- # 5. Initialize Trainer
73
- trainer = GRPOTrainer(
74
- model = model,
75
- reward_funcs = [reward_function],
76
- args = training_args,
77
- train_dataset = [
78
- {"prompt": "Write a function `sum_list(arr: list) -> int` that returns the sum of a list."}
79
- ] * 100, # Dummy dataset for V0 validation
80
- )
81
 
82
- if __name__ == "__main__":
83
- print("Starting V0 Training...")
84
- trainer.train()
85
- model.save_pretrained_merged("final_v0_model", tokenizer, save_method = "merged_16bit")
86
- =======
87
- )
 
 
 
 
 
 
 
 
 
88
 
89
- # 2. V2 Heuristic State Machine
 
90
  class CurriculumManager:
91
- def __init__(self):
92
- self.difficulties = ["easy", "medium", "hard"]
93
- self.current_idx = 0
94
- self.success_history = []
95
- self.window_size = 10 # Moving average window
96
 
97
- def get_current_difficulty(self):
98
  return self.difficulties[self.current_idx]
99
 
100
- def update(self, success_rate):
101
- self.success_history.append(success_rate)
102
  if len(self.success_history) > self.window_size:
103
  self.success_history.pop(0)
104
-
105
- # V2 Logic: If moving average > 70%, increase difficulty
106
- avg_success = np.mean(self.success_history)
107
- if avg_success > 0.70 and self.current_idx < len(self.difficulties) - 1:
108
  self.current_idx += 1
109
- print(f"--- HEURISTIC LEVEL UP: Moving to {self.difficulties[self.current_idx]} ---")
110
- self.success_history = [] # Reset for the new tier
111
-
112
- curriculum = CurriculumManager()
113
-
114
- # 3. V2 Reward Function with Curriculum Feedback
115
- def v2_reward_func(prompts, completions, **kwargs) -> list[float]:
116
- env = AdaptEnvironment()
117
- rewards = []
118
- successes = []
119
-
120
- current_diff = curriculum.get_current_difficulty()
121
-
122
- for completion in completions:
123
- # Load problem based on current heuristic difficulty
124
- env.reset(difficulty=current_diff)
125
-
126
- code = completion.split("```python")[-1].split("```")[0].strip() if "```" in completion else completion
127
- action = AdaptAction(code=code)
128
- obs = env.step(action)
129
-
130
- rewards.append(float(obs.reward))
131
- successes.append(1.0 if obs.pass_rate == 1.0 else 0.0)
132
-
133
- # Update the curriculum manager based on this batch
134
- batch_success_rate = np.mean(successes)
135
- curriculum.update(batch_success_rate)
136
-
137
- return rewards
138
-
139
- # 4. Dataset: Transition from single prompt to generic instruction
140
- # This forces the LLM to look at the 'problem statement' in the observation
141
- dataset = [
142
- {"prompt": "Read the problem statement and constraints carefully. Write a Python solution that reads from stdin and prints to stdout."}
143
- ] * 200 # Larger dataset for multi-tier learning
144
-
145
- # 5. Config
146
- training_args = GRPOConfig(
147
- output_dir = "./outputs_v2",
148
- learning_rate = 5e-6,
149
- per_device_train_batch_size = 1,
150
- gradient_accumulation_steps = 8, # Higher for stability during transitions
151
- num_generations = 8,
152
- max_steps = 250,
153
- bf16 = True,
154
- logging_steps = 1,
155
- )
156
-
157
- trainer = GRPOTrainer(
158
- model = model,
159
- reward_funcs = [v2_reward_func],
160
- args = training_args,
161
- train_dataset = dataset,
162
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
 
164
  if __name__ == "__main__":
165
- print(f"Starting V2 Training. Initial Difficulty: {curriculum.get_current_difficulty()}")
166
- trainer.train()
167
- >>>>>>> environment-v2
 
1
+ from __future__ import annotations
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
+ import argparse
4
+ from dataclasses import dataclass, field
5
+ from typing import Any
6
+
7
+ from env.adapt_env import AdaptEnvironment
8
+ from models import AdaptAction
9
+
10
+
11
def extract_code(completion: str) -> str:
    """Extract Python source from a model completion.

    A fence tagged ``python`` is preferred when present; otherwise the first
    fenced block is used. Completions without a fence are returned stripped.

    Args:
        completion: Raw completion text.

    Returns:
        The code inside the chosen fence with its language tag and
        surrounding whitespace removed, or the stripped completion.
    """
    text = completion.strip()
    if "```" not in text:
        return text

    marker = "```python" if "```python" in text else "```"
    fenced = text.split(marker, 1)[1].split("```", 1)[0]

    # Whatever directly follows the opening fence on the same line is the
    # language tag (e.g. "py", or the "3" left over from "python3"), not
    # code. Only drop it when it looks like a tag, so code written on the
    # fence line itself is preserved.
    opener, newline, remainder = fenced.partition("\n")
    if newline and all(ch.isalnum() or ch in "_+-#." for ch in opener.rstrip()):
        fenced = remainder
    return fenced.strip()
18
 
19
+
20
@dataclass
class CurriculumManager:
    """Track recent batch success rates and promote to harder tiers.

    Each call to :meth:`update` records one batch success rate. When the
    mean of the retained rates exceeds ``promotion_threshold`` and a harder
    tier exists, the manager advances one tier and clears its history.
    """

    # Ordered tiers, easiest first.
    difficulties: list[str] = field(default_factory=lambda: ["easy", "medium", "hard"])
    # Index into ``difficulties`` of the active tier.
    current_idx: int = 0
    # Most recent batch success rates for the active tier, oldest first.
    success_history: list[float] = field(default_factory=list)
    # Maximum number of rates retained for the moving average.
    window_size: int = 10
    # The moving average must be strictly greater than this to promote.
    promotion_threshold: float = 0.70
    # Minimum number of recorded batches before a promotion is allowed;
    # the default of 1 allows promotion on the first batch of a tier.
    min_samples: int = 1

    def current_difficulty(self) -> str:
        """Return the label of the active difficulty tier."""
        return self.difficulties[self.current_idx]

    def update(self, batch_success_rate: float) -> None:
        """Record one batch success rate and promote if the average allows.

        Args:
            batch_success_rate: Fraction of fully solved completions in the
                batch.
        """
        self.success_history.append(float(batch_success_rate))

        # Clamp the window to at least one entry so a zero or negative
        # window_size cannot empty the history and divide by zero below.
        window = max(1, self.window_size)
        overflow = len(self.success_history) - window
        if overflow > 0:
            del self.success_history[:overflow]

        moving_average = sum(self.success_history) / len(self.success_history)
        has_harder_tier = self.current_idx < len(self.difficulties) - 1
        enough_samples = len(self.success_history) >= self.min_samples
        if moving_average > self.promotion_threshold and has_harder_tier and enough_samples:
            self.current_idx += 1
            # Start the new tier with a clean slate.
            self.success_history.clear()
            print(
                f"[curriculum] promoted to {self.current_difficulty()} "
                f"(moving_success={moving_average:.2f})"
            )
43
+
44
+
45
+ def build_reward_func(curriculum: CurriculumManager):
46
+ def reward_func(prompts, completions, **kwargs) -> list[float]:
47
+ del prompts, kwargs
48
+ env = AdaptEnvironment()
49
+ rewards: list[float] = []
50
+ successes: list[float] = []
51
+ difficulty = curriculum.current_difficulty()
52
+
53
+ for completion in completions:
54
+ env.reset(difficulty=difficulty)
55
+ observation = env.step(AdaptAction(code=extract_code(completion)))
56
+ rewards.append(float(observation.reward))
57
+ successes.append(1.0 if observation.pass_rate == 1.0 else 0.0)
58
+
59
+ if successes:
60
+ curriculum.update(sum(successes) / len(successes))
61
+
62
+ return rewards
63
+
64
+ return reward_func
65
+
66
+
67
def build_dataset(size: int) -> list[dict[str, str]]:
    """Build a list of identical instruction prompts for GRPO training.

    Args:
        size: Number of rows to produce; zero or a negative value yields an
            empty list.

    Returns:
        A list of ``{"prompt": ...}`` rows. Each row is its own dict, so
        mutating one row does not affect the others.
    """
    prompt = (
        "Read the problem statement carefully. "
        "Write a Python solution that reads from stdin and prints to stdout."
    )
    # A comprehension creates a fresh dict per row; ``[row] * size`` would
    # repeat one shared dict object ``size`` times.
    return [{"prompt": prompt} for _ in range(size)]
73
+
74
+
75
+ def run_training(args: argparse.Namespace) -> None:
76
+ try:
77
+ from trl import GRPOConfig, GRPOTrainer
78
+ from unsloth import FastLanguageModel, PatchFastRL
79
+ except ImportError as exc:
80
+ raise RuntimeError(
81
+ "Training dependencies are missing. Install `trl` and `unsloth` before running GRPO training."
82
+ ) from exc
83
+
84
+ PatchFastRL("GRPO", FastLanguageModel)
85
+
86
+ model, tokenizer = FastLanguageModel.from_pretrained(
87
+ model_name=args.model_name,
88
+ max_seq_length=args.max_seq_length,
89
+ load_in_4bit=not args.disable_4bit,
90
+ )
91
+
92
+ model = FastLanguageModel.get_peft_model(
93
+ model,
94
+ r=args.lora_rank,
95
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
96
+ lora_alpha=args.lora_alpha,
97
+ lora_dropout=0.0,
98
+ )
99
+
100
+ curriculum = CurriculumManager()
101
+ training_args = GRPOConfig(
102
+ output_dir=args.output_dir,
103
+ learning_rate=args.learning_rate,
104
+ per_device_train_batch_size=args.batch_size,
105
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
106
+ num_generations=args.num_generations,
107
+ max_prompt_length=args.max_prompt_length,
108
+ max_completion_length=args.max_completion_length,
109
+ max_steps=args.max_steps,
110
+ logging_steps=1,
111
+ bf16=args.bf16,
112
+ )
113
+
114
+ trainer = GRPOTrainer(
115
+ model=model,
116
+ reward_funcs=[build_reward_func(curriculum)],
117
+ args=training_args,
118
+ train_dataset=build_dataset(args.dataset_size),
119
+ )
120
+ trainer.train()
121
+ model.save_pretrained(args.output_dir)
122
+ tokenizer.save_pretrained(args.output_dir)
123
+
124
+
125
+ def build_parser() -> argparse.ArgumentParser:
126
+ parser = argparse.ArgumentParser(description="GRPO training entrypoint for the ADAPT DSA environment.")
127
+ parser.add_argument("--model-name", default="unsloth/Llama-3.2-3B-Instruct")
128
+ parser.add_argument("--output-dir", default="outputs_v2")
129
+ parser.add_argument("--dataset-size", type=int, default=200)
130
+ parser.add_argument("--max-steps", type=int, default=250)
131
+ parser.add_argument("--batch-size", type=int, default=1)
132
+ parser.add_argument("--gradient-accumulation-steps", type=int, default=8)
133
+ parser.add_argument("--num-generations", type=int, default=8)
134
+ parser.add_argument("--max-seq-length", type=int, default=2048)
135
+ parser.add_argument("--max-prompt-length", type=int, default=512)
136
+ parser.add_argument("--max-completion-length", type=int, default=512)
137
+ parser.add_argument("--learning-rate", type=float, default=5e-6)
138
+ parser.add_argument("--lora-rank", type=int, default=16)
139
+ parser.add_argument("--lora-alpha", type=int, default=16)
140
+ parser.add_argument("--disable-4bit", action="store_true")
141
+ parser.add_argument("--bf16", action="store_true")
142
+ return parser
143
+
144
+
145
+ def main(argv: list[str] | None = None) -> None:
146
+ parser = build_parser()
147
+ args = parser.parse_args(argv)
148
+ run_training(args)
149
+
150
 
151
  if __name__ == "__main__":
152
+ main()
 
 
verifier/sandbox.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  import subprocess
 
3
  import tempfile
4
 
5
 
@@ -11,7 +12,7 @@ def run_code(code: str, stdin: str, timeout: int = 2):
11
  path = f.name
12
 
13
  result = subprocess.run(
14
- ["python3", path],
15
  input=stdin,
16
  text=True,
17
  capture_output=True,
@@ -31,4 +32,4 @@ def run_code(code: str, stdin: str, timeout: int = 2):
31
 
32
  finally:
33
  if path and os.path.exists(path):
34
- os.remove(path)
 
1
  import os
2
  import subprocess
3
+ import sys
4
  import tempfile
5
 
6
 
 
12
  path = f.name
13
 
14
  result = subprocess.run(
15
+ [sys.executable, path],
16
  input=stdin,
17
  text=True,
18
  capture_output=True,
 
32
 
33
  finally:
34
  if path and os.path.exists(path):
35
+ os.remove(path)
verifier/verifier.py CHANGED
@@ -5,7 +5,13 @@ from verifier.metrics import compute_pass_rate
5
  def verify(code: str, test_cases):
6
  results = []
7
 
8
- for stdin, expected in test_cases:
 
 
 
 
 
 
9
  ok, output = run_code(code, stdin)
10
 
11
  passed = ok and output.strip() == expected.strip()
@@ -23,4 +29,4 @@ def verify(code: str, test_cases):
23
  return reward, {
24
  **metrics,
25
  "results": results,
26
- }
 
5
  def verify(code: str, test_cases):
6
  results = []
7
 
8
+ for test_case in test_cases:
9
+ if isinstance(test_case, dict):
10
+ stdin = str(test_case.get("input", ""))
11
+ expected = str(test_case.get("output", ""))
12
+ else:
13
+ stdin, expected = test_case
14
+
15
  ok, output = run_code(code, stdin)
16
 
17
  passed = ok and output.strip() == expected.strip()
 
29
  return reward, {
30
  **metrics,
31
  "results": results,
32
+ }