Spaces:
Running
Running
Commit ·
96b50a5
1
Parent(s): 864223c
Add HF support
Browse files- README.md +38 -1
- client.py +1 -1
- env/adapt_env.py +14 -3
- env/executor.py +1 -1
- inference.py +159 -0
- models.py +16 -1
- pyproject.toml +1 -1
- scripts/test_env.py +9 -0
- scripts/test_verifier.py +10 -1
- server/app.py +150 -16
- test.py +1 -1
- training/train_grpo.py +141 -156
- verifier/sandbox.py +3 -2
- verifier/verifier.py +8 -2
README.md
CHANGED
|
@@ -14,7 +14,7 @@ tags:
|
|
| 14 |
|
| 15 |
ADAPT, the Adversarial DSA Tutor, is an OpenEnv-compliant RLVR environment for training code-generation agents on small DSA tasks. The agent receives a problem prompt, examples, and visible tests, then submits Python code. The environment runs the code against visible and hidden tests and returns reward, pass-rate metrics, execution status, and feedback.
|
| 16 |
|
| 17 |
-
This repo
|
| 18 |
|
| 19 |
## Why This Environment
|
| 20 |
|
|
@@ -120,11 +120,15 @@ uvicorn server.app:app --host 0.0.0.0 --port 7860
|
|
| 120 |
|
| 121 |
Useful endpoints:
|
| 122 |
|
|
|
|
| 123 |
- `GET /health`
|
|
|
|
|
|
|
| 124 |
- `GET /schema`
|
| 125 |
- `POST /reset`
|
| 126 |
- `POST /step`
|
| 127 |
- `GET /state`
|
|
|
|
| 128 |
|
| 129 |
Example step request:
|
| 130 |
|
|
@@ -132,12 +136,45 @@ Example step request:
|
|
| 132 |
curl -X POST http://localhost:7860/step -H "Content-Type: application/json" -d "{\"action\":{\"code\":\"n=int(input())\nprint(n*2)\"}}"
|
| 133 |
```
|
| 134 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
Validate with OpenEnv once dependencies are installed:
|
| 136 |
|
| 137 |
```powershell
|
| 138 |
openenv validate .
|
| 139 |
```
|
| 140 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
## Hugging Face Spaces
|
| 142 |
|
| 143 |
This repo is Docker Space ready:
|
|
|
|
| 14 |
|
| 15 |
ADAPT, the Adversarial DSA Tutor, is an OpenEnv-compliant RLVR environment for training code-generation agents on small DSA tasks. The agent receives a problem prompt, examples, and visible tests, then submits Python code. The environment runs the code against visible and hidden tests and returns reward, pass-rate metrics, execution status, and feedback.
|
| 16 |
|
| 17 |
+
This repo includes the environment, verifier helpers, a baseline inference runner, and a GRPO training entrypoint so the full submission flow can be exercised from one codebase.
|
| 18 |
|
| 19 |
## Why This Environment
|
| 20 |
|
|
|
|
| 120 |
|
| 121 |
Useful endpoints:
|
| 122 |
|
| 123 |
+
- `GET /`
|
| 124 |
- `GET /health`
|
| 125 |
+
- `GET /metadata`
|
| 126 |
+
- `GET /tasks`
|
| 127 |
- `GET /schema`
|
| 128 |
- `POST /reset`
|
| 129 |
- `POST /step`
|
| 130 |
- `GET /state`
|
| 131 |
+
- `POST /mcp`
|
| 132 |
|
| 133 |
Example step request:
|
| 134 |
|
|
|
|
| 136 |
curl -X POST http://localhost:7860/step -H "Content-Type: application/json" -d "{\"action\":{\"code\":\"n=int(input())\nprint(n*2)\"}}"
|
| 137 |
```
|
| 138 |
|
| 139 |
+
You can also send the raw action body:
|
| 140 |
+
|
| 141 |
+
```powershell
|
| 142 |
+
curl -X POST http://localhost:7860/step -H "Content-Type: application/json" -d "{\"code\":\"n=int(input())\nprint(n*2)\"}"
|
| 143 |
+
```
|
| 144 |
+
|
| 145 |
Validate with OpenEnv once dependencies are installed:
|
| 146 |
|
| 147 |
```powershell
|
| 148 |
openenv validate .
|
| 149 |
```
|
| 150 |
|
| 151 |
+
Run the verifier smoke test:
|
| 152 |
+
|
| 153 |
+
```powershell
|
| 154 |
+
python scripts\test_verifier.py
|
| 155 |
+
```
|
| 156 |
+
|
| 157 |
+
Run the environment smoke test:
|
| 158 |
+
|
| 159 |
+
```powershell
|
| 160 |
+
python scripts\test_env.py
|
| 161 |
+
```
|
| 162 |
+
|
| 163 |
+
Run the baseline model loop:
|
| 164 |
+
|
| 165 |
+
```powershell
|
| 166 |
+
$env:HF_TOKEN="..."
|
| 167 |
+
$env:API_BASE_URL="https://router.huggingface.co/v1"
|
| 168 |
+
$env:MODEL_NAME="openai/gpt-oss-120b"
|
| 169 |
+
python inference.py
|
| 170 |
+
```
|
| 171 |
+
|
| 172 |
+
Run GRPO training:
|
| 173 |
+
|
| 174 |
+
```powershell
|
| 175 |
+
python training\train_grpo.py --output-dir outputs_v2 --bf16
|
| 176 |
+
```
|
| 177 |
+
|
| 178 |
## Hugging Face Spaces
|
| 179 |
|
| 180 |
This repo is Docker Space ready:
|
client.py
CHANGED
|
@@ -21,7 +21,7 @@ class AdaptEnvClient:
|
|
| 21 |
return response.json()
|
| 22 |
|
| 23 |
def step(self, code: str) -> dict[str, Any]:
|
| 24 |
-
response = self._client.post("/step", json=
|
| 25 |
response.raise_for_status()
|
| 26 |
return response.json()
|
| 27 |
|
|
|
|
| 21 |
return response.json()
|
| 22 |
|
| 23 |
def step(self, code: str) -> dict[str, Any]:
|
| 24 |
+
response = self._client.post("/step", json=AdaptAction(code=code).model_dump())
|
| 25 |
response.raise_for_status()
|
| 26 |
return response.json()
|
| 27 |
|
env/adapt_env.py
CHANGED
|
@@ -1,15 +1,26 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
import ast
|
| 4 |
-
from typing import Any
|
| 5 |
from uuid import uuid4
|
| 6 |
|
| 7 |
-
from openenv.core.env_server.interfaces import Environment
|
| 8 |
-
|
| 9 |
from env.executor import run_code
|
| 10 |
from env.test_cases import get_test_cases, load_problem, split_test_cases
|
| 11 |
from models import AdaptAction, AdaptObservation, AdaptState
|
| 12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
FORBIDDEN_IMPORTS = {"os", "pathlib", "shutil", "socket", "subprocess"}
|
| 15 |
DIFFICULTY_LABELS = {1: "easy", 2: "medium", 3: "hard"}
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
import ast
|
| 4 |
+
from typing import Any, Generic, TypeVar
|
| 5 |
from uuid import uuid4
|
| 6 |
|
|
|
|
|
|
|
| 7 |
from env.executor import run_code
|
| 8 |
from env.test_cases import get_test_cases, load_problem, split_test_cases
|
| 9 |
from models import AdaptAction, AdaptObservation, AdaptState
|
| 10 |
|
| 11 |
+
try:
|
| 12 |
+
from openenv.core.env_server.interfaces import Environment
|
| 13 |
+
except ImportError:
|
| 14 |
+
ActionT = TypeVar("ActionT")
|
| 15 |
+
ObservationT = TypeVar("ObservationT")
|
| 16 |
+
StateT = TypeVar("StateT")
|
| 17 |
+
|
| 18 |
+
class Environment(Generic[ActionT, ObservationT, StateT]):
|
| 19 |
+
SUPPORTS_CONCURRENT_SESSIONS = False
|
| 20 |
+
|
| 21 |
+
def __init__(self) -> None:
|
| 22 |
+
pass
|
| 23 |
+
|
| 24 |
|
| 25 |
FORBIDDEN_IMPORTS = {"os", "pathlib", "shutil", "socket", "subprocess"}
|
| 26 |
DIFFICULTY_LABELS = {1: "easy", 2: "medium", 3: "hard"}
|
env/executor.py
CHANGED
|
@@ -23,7 +23,7 @@ def run_code(code: str, input_data: str) -> dict:
|
|
| 23 |
|
| 24 |
try:
|
| 25 |
result = subprocess.run(
|
| 26 |
-
[
|
| 27 |
input=input_data,
|
| 28 |
text=True,
|
| 29 |
capture_output=True,
|
|
|
|
| 23 |
|
| 24 |
try:
|
| 25 |
result = subprocess.run(
|
| 26 |
+
[sys.executable, str(file_path)],
|
| 27 |
input=input_data,
|
| 28 |
text=True,
|
| 29 |
capture_output=True,
|
inference.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
STDOUT FORMAT (must match exactly):
|
| 3 |
+
[START] task=<task_name> env=adapt_dsa_tutor model=<model_name>
|
| 4 |
+
[STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
|
| 5 |
+
[END] success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...,rn>
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import json
|
| 11 |
+
import os
|
| 12 |
+
from typing import Any
|
| 13 |
+
|
| 14 |
+
from env.adapt_env import AdaptEnvironment
|
| 15 |
+
from env.test_cases import load_problem_bank
|
| 16 |
+
from models import AdaptAction
|
| 17 |
+
|
| 18 |
+
BENCHMARK = "adapt_dsa_tutor"
|
| 19 |
+
TASKS = [problem["problem_id"] for problem in load_problem_bank()]
|
| 20 |
+
SYSTEM_PROMPT = """You are solving a programming problem in Python.
|
| 21 |
+
|
| 22 |
+
You will receive:
|
| 23 |
+
- a problem statement
|
| 24 |
+
- input format
|
| 25 |
+
- constraints
|
| 26 |
+
- worked examples
|
| 27 |
+
- visible tests
|
| 28 |
+
- feedback from previous attempts
|
| 29 |
+
|
| 30 |
+
Reply with ONLY runnable Python code. The code must read from stdin and print to stdout.
|
| 31 |
+
Do not include markdown fences or explanations."""
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def require_env(name: str, value: str | None) -> str:
|
| 35 |
+
if value:
|
| 36 |
+
return value
|
| 37 |
+
raise RuntimeError(f"Missing required environment variable: {name}")
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def safe_log_value(value: str | None) -> str:
    """Render *value* as a single whitespace-free token for the log lines.

    The ``[STEP]`` record is a space-separated ``key=value`` line, so a value
    must not contain anything that could split the line or the field.

    Args:
        value: Text to embed in a log record, or ``None``.

    Returns:
        ``"null"`` when *value* is ``None`` or empty; otherwise the text with
        every whitespace character replaced, one for one, by ``"_"``.
    """
    if not value:
        return "null"
    # str.isspace() also covers vertical tab, form feed and the Unicode line
    # separators, which would otherwise still break the one-line format.
    return "".join("_" if ch.isspace() else ch for ch in str(value))
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def extract_code(response_text: str) -> str:
    """Strip an optional Markdown code fence from a model reply.

    Args:
        response_text: Raw text returned by the model.

    Returns:
        The reply with a leading ```` ``` ```` fence line (and its language
        tag), a trailing ```` ``` ```` fence and surrounding whitespace
        removed. Text that does not start with a fence is returned stripped
        but otherwise unchanged.
    """
    text = response_text.strip()
    if not text.startswith("```"):
        return text

    _fence_line, newline, remainder = text.partition("\n")
    # With a newline the whole opening line (fence + language tag) is dropped;
    # without one the reply is a single-line fence and only the backticks go.
    text = remainder if newline else text[3:]
    if text.endswith("```"):
        text = text[:-3]
    text = text.strip()

    if not newline:
        # Single-line form such as "```python print(1)```": the tag is still
        # attached. Only drop it when it is the whole first word, so code that
        # merely begins with an identifier like `python_version` is preserved.
        words = text.split(None, 1)
        if words and words[0] == "python":
            text = words[1] if len(words) > 1 else ""
    return text
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def build_user_prompt(observation: dict[str, Any]) -> str:
|
| 60 |
+
payload = {
|
| 61 |
+
"problem_id": observation["problem_id"],
|
| 62 |
+
"difficulty": observation["difficulty"],
|
| 63 |
+
"problem": observation["problem"],
|
| 64 |
+
"input_format": observation["input_format"],
|
| 65 |
+
"constraints": observation["constraints"],
|
| 66 |
+
"examples": observation["examples"],
|
| 67 |
+
"visible_tests": observation["visible_tests"],
|
| 68 |
+
"feedback": observation["feedback"],
|
| 69 |
+
}
|
| 70 |
+
return json.dumps(payload, indent=2)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def log_start(task: str, env: str, model: str) -> None:
|
| 74 |
+
print(f"[START] task={task} env={env} model={model}", flush=True)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def log_step(step: int, action_str: str, reward: float, done: bool, error: str | None) -> None:
|
| 78 |
+
print(
|
| 79 |
+
f"[STEP] step={step} action={safe_log_value(action_str)} reward={reward:.2f} "
|
| 80 |
+
f"done={str(done).lower()} error={safe_log_value(error)}",
|
| 81 |
+
flush=True,
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def log_end(success: bool, steps: int, score: float, rewards: list[float]) -> None:
|
| 86 |
+
rewards_str = ",".join(f"{reward:.2f}" for reward in rewards)
|
| 87 |
+
print(
|
| 88 |
+
f"[END] success={str(success).lower()} steps={steps} score={score:.2f} rewards={rewards_str}",
|
| 89 |
+
flush=True,
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def run_task(task_name: str) -> float:
|
| 94 |
+
try:
|
| 95 |
+
from openai import OpenAI
|
| 96 |
+
except ImportError as exc:
|
| 97 |
+
raise RuntimeError(
|
| 98 |
+
"The `openai` package is required for inference runs. Install it before running inference.py."
|
| 99 |
+
) from exc
|
| 100 |
+
|
| 101 |
+
api_key = require_env("HF_TOKEN", os.getenv("HF_TOKEN"))
|
| 102 |
+
base_url = require_env("API_BASE_URL", os.getenv("API_BASE_URL", "https://router.huggingface.co/v1"))
|
| 103 |
+
model_name = require_env("MODEL_NAME", os.getenv("MODEL_NAME", "openai/gpt-oss-120b"))
|
| 104 |
+
|
| 105 |
+
client = OpenAI(base_url=base_url, api_key=api_key)
|
| 106 |
+
env = AdaptEnvironment()
|
| 107 |
+
observation = env.reset(problem_id=task_name)
|
| 108 |
+
|
| 109 |
+
log_start(task_name, BENCHMARK, model_name)
|
| 110 |
+
rewards = []
|
| 111 |
+
messages = [{"role": "system", "content": SYSTEM_PROMPT}]
|
| 112 |
+
|
| 113 |
+
max_steps = 3
|
| 114 |
+
for step_index in range(1, max_steps + 1):
|
| 115 |
+
messages.append({"role": "user", "content": build_user_prompt(observation.model_dump())})
|
| 116 |
+
|
| 117 |
+
try:
|
| 118 |
+
response = client.chat.completions.create(
|
| 119 |
+
model=model_name,
|
| 120 |
+
messages=messages,
|
| 121 |
+
temperature=0.0,
|
| 122 |
+
max_tokens=512,
|
| 123 |
+
)
|
| 124 |
+
response_text = response.choices[0].message.content or ""
|
| 125 |
+
messages.append({"role": "assistant", "content": response_text})
|
| 126 |
+
|
| 127 |
+
code = extract_code(response_text)
|
| 128 |
+
observation = env.step(AdaptAction(code=code))
|
| 129 |
+
rewards.append(float(observation.reward))
|
| 130 |
+
log_step(step_index, "submit_code", float(observation.reward), bool(observation.done), None)
|
| 131 |
+
|
| 132 |
+
if observation.pass_rate == 1.0 or observation.done:
|
| 133 |
+
break
|
| 134 |
+
|
| 135 |
+
except Exception as exc:
|
| 136 |
+
rewards.append(0.0)
|
| 137 |
+
log_step(step_index, "parse_error", 0.0, False, str(exc))
|
| 138 |
+
messages.append(
|
| 139 |
+
{
|
| 140 |
+
"role": "user",
|
| 141 |
+
"content": f"Your last response failed. Error: {exc}. Reply with only Python code.",
|
| 142 |
+
}
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
success = observation.pass_rate == 1.0
|
| 146 |
+
score = float(observation.reward)
|
| 147 |
+
log_end(success, len(rewards), score, rewards)
|
| 148 |
+
return score
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def main() -> dict[str, float]:
|
| 152 |
+
scores = {}
|
| 153 |
+
for task in TASKS:
|
| 154 |
+
scores[task] = run_task(task)
|
| 155 |
+
return scores
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
if __name__ == "__main__":
|
| 159 |
+
main()
|
models.py
CHANGED
|
@@ -2,7 +2,22 @@ from __future__ import annotations
|
|
| 2 |
|
| 3 |
from typing import Any
|
| 4 |
|
| 5 |
-
from
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
from pydantic import Field
|
| 7 |
|
| 8 |
|
|
|
|
| 2 |
|
| 3 |
from typing import Any
|
| 4 |
|
| 5 |
+
from pydantic import BaseModel, Field
|
| 6 |
+
|
| 7 |
+
try:
|
| 8 |
+
from openenv.core.env_server.types import Action, Observation, State
|
| 9 |
+
except ImportError:
|
| 10 |
+
class Action(BaseModel):
|
| 11 |
+
model_config = {"extra": "forbid"}
|
| 12 |
+
|
| 13 |
+
class Observation(BaseModel):
|
| 14 |
+
reward: float = Field(default=0.0, ge=0.0, le=1.0)
|
| 15 |
+
done: bool = False
|
| 16 |
+
|
| 17 |
+
class State(BaseModel):
|
| 18 |
+
episode_id: str = ""
|
| 19 |
+
step_count: int = 0
|
| 20 |
+
|
| 21 |
from pydantic import Field
|
| 22 |
|
| 23 |
|
pyproject.toml
CHANGED
|
@@ -26,5 +26,5 @@ server = "server.app:main"
|
|
| 26 |
|
| 27 |
[tool.setuptools]
|
| 28 |
include-package-data = true
|
| 29 |
-
packages = ["env", "server"]
|
| 30 |
py-modules = ["app", "client", "models"]
|
|
|
|
| 26 |
|
| 27 |
[tool.setuptools]
|
| 28 |
include-package-data = true
|
| 29 |
+
packages = ["env", "server", "verifier"]
|
| 30 |
py-modules = ["app", "client", "models"]
|
scripts/test_env.py
CHANGED
|
@@ -1,3 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from env.adapt_env import AdaptEnvironment
|
| 2 |
from models import AdaptAction
|
| 3 |
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import sys
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
ROOT = Path(__file__).resolve().parents[1]
|
| 7 |
+
if str(ROOT) not in sys.path:
|
| 8 |
+
sys.path.insert(0, str(ROOT))
|
| 9 |
+
|
| 10 |
from env.adapt_env import AdaptEnvironment
|
| 11 |
from models import AdaptAction
|
| 12 |
|
scripts/test_verifier.py
CHANGED
|
@@ -1,3 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from verifier.verifier import verify
|
| 2 |
|
| 3 |
|
|
@@ -38,4 +47,4 @@ for name, code in [
|
|
| 38 |
print("Pass rate:", info["pass_rate"])
|
| 39 |
print("Passed:", info["passed"], "/", info["total"])
|
| 40 |
print("Timeouts:", info["timeout_count"])
|
| 41 |
-
print("Errors:", info["error_count"])
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import sys
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
ROOT = Path(__file__).resolve().parents[1]
|
| 7 |
+
if str(ROOT) not in sys.path:
|
| 8 |
+
sys.path.insert(0, str(ROOT))
|
| 9 |
+
|
| 10 |
from verifier.verifier import verify
|
| 11 |
|
| 12 |
|
|
|
|
| 47 |
print("Pass rate:", info["pass_rate"])
|
| 48 |
print("Passed:", info["passed"], "/", info["total"])
|
| 49 |
print("Timeouts:", info["timeout_count"])
|
| 50 |
+
print("Errors:", info["error_count"])
|
server/app.py
CHANGED
|
@@ -1,28 +1,162 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
except Exception as exc: # pragma: no cover
|
| 6 |
-
raise ImportError(
|
| 7 |
-
"openenv-core>=0.2.3 is required. Install with: pip install -e ."
|
| 8 |
-
) from exc
|
| 9 |
|
| 10 |
-
|
| 11 |
-
from
|
|
|
|
|
|
|
| 12 |
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
env_name="adapt_dsa_tutor",
|
| 19 |
-
max_concurrent_envs=4,
|
| 20 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
-
def main(host: str = "0.0.0.0", port: int = 7860) -> None:
|
| 24 |
-
import uvicorn
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
uvicorn.run(app, host=host, port=port)
|
| 27 |
|
| 28 |
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
+
import argparse
|
| 4 |
+
from typing import Any
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
+
import uvicorn
|
| 7 |
+
from fastapi import Body, FastAPI, HTTPException, Request
|
| 8 |
+
from fastapi.responses import RedirectResponse, Response
|
| 9 |
+
from pydantic import BaseModel
|
| 10 |
|
| 11 |
+
from env.adapt_env import AdaptEnvironment
|
| 12 |
+
from env.test_cases import load_problem_bank
|
| 13 |
+
from models import AdaptAction, AdaptObservation, AdaptState
|
| 14 |
|
| 15 |
+
ENV_NAME = "adapt_dsa_tutor"
|
| 16 |
+
ENV_DESCRIPTION = (
|
| 17 |
+
"RL environment for DSA code generation with hidden tests, tiered problems, "
|
| 18 |
+
"and verifier-aware reward shaping."
|
|
|
|
|
|
|
| 19 |
)
|
| 20 |
+
TASKS = [
|
| 21 |
+
{
|
| 22 |
+
"name": problem["problem_id"],
|
| 23 |
+
"difficulty": problem["difficulty"],
|
| 24 |
+
"description": problem["problem"],
|
| 25 |
+
}
|
| 26 |
+
for problem in load_problem_bank()
|
| 27 |
+
]
|
| 28 |
+
|
| 29 |
+
app = FastAPI(title="ADAPT DSA Tutor OpenEnv", version="0.2.0")
|
| 30 |
+
ENV = AdaptEnvironment()
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class ResetRequest(BaseModel):
|
| 34 |
+
seed: int | None = None
|
| 35 |
+
episode_id: str | None = None
|
| 36 |
+
problem_id: str | None = None
|
| 37 |
+
difficulty: str | None = None
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _metadata() -> dict[str, Any]:
|
| 41 |
+
return {
|
| 42 |
+
"name": ENV_NAME,
|
| 43 |
+
"description": ENV_DESCRIPTION,
|
| 44 |
+
"version": "0.2.0",
|
| 45 |
+
"tasks": TASKS,
|
| 46 |
+
"mode": "simulation",
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@app.get("/")
|
| 51 |
+
def root() -> dict[str, Any]:
|
| 52 |
+
payload = _metadata()
|
| 53 |
+
payload["status"] = "ok"
|
| 54 |
+
return payload
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@app.get("/web", include_in_schema=False)
|
| 58 |
+
def web_root() -> RedirectResponse:
|
| 59 |
+
return RedirectResponse(url="/", status_code=307)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
@app.get("/web/", include_in_schema=False)
|
| 63 |
+
def web_root_slash() -> RedirectResponse:
|
| 64 |
+
return RedirectResponse(url="/", status_code=307)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
@app.get("/favicon.ico", include_in_schema=False)
|
| 68 |
+
def favicon() -> Response:
|
| 69 |
+
return Response(status_code=204)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
@app.get("/health")
|
| 73 |
+
def health() -> dict[str, str]:
|
| 74 |
+
return {"status": "healthy"}
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
@app.get("/metadata")
|
| 78 |
+
def metadata() -> dict[str, Any]:
|
| 79 |
+
return _metadata()
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
@app.get("/tasks")
|
| 83 |
+
def list_tasks() -> dict[str, Any]:
|
| 84 |
+
return {"tasks": TASKS}
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
@app.get("/schema")
|
| 88 |
+
def schema() -> dict[str, Any]:
|
| 89 |
+
return {
|
| 90 |
+
"action": AdaptAction.model_json_schema(),
|
| 91 |
+
"observation": AdaptObservation.model_json_schema(),
|
| 92 |
+
"state": AdaptState.model_json_schema(),
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
@app.post("/mcp")
|
| 97 |
+
def mcp(payload: dict[str, Any] = Body(default_factory=dict)) -> dict[str, Any]:
|
| 98 |
+
return {
|
| 99 |
+
"jsonrpc": "2.0",
|
| 100 |
+
"id": payload.get("id"),
|
| 101 |
+
"error": {
|
| 102 |
+
"code": -32601,
|
| 103 |
+
"message": "MCP methods are not implemented for this environment.",
|
| 104 |
+
},
|
| 105 |
+
}
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
@app.post("/reset")
|
| 109 |
+
def reset(request: ResetRequest | None = None) -> dict[str, Any]:
|
| 110 |
+
effective_request = request or ResetRequest()
|
| 111 |
+
observation = ENV.reset(
|
| 112 |
+
seed=effective_request.seed,
|
| 113 |
+
episode_id=effective_request.episode_id,
|
| 114 |
+
problem_id=effective_request.problem_id,
|
| 115 |
+
difficulty=effective_request.difficulty,
|
| 116 |
+
)
|
| 117 |
+
return observation.model_dump()
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
@app.post("/step")
async def step(request: Request) -> dict[str, Any]:
    """Validate a submitted action, run it through the shared environment, and report the result.

    The body may be either ``{"action": {...}}`` or the bare action object;
    both are validated against ``AdaptAction``.

    Raises:
        HTTPException: 422 when the body is not valid JSON, is not a JSON
            object, or does not validate as an ``AdaptAction``.
    """
    try:
        payload = await request.json()
    except ValueError as exc:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors; an
        # empty or malformed body is a client error, not a server failure.
        raise HTTPException(status_code=422, detail="Request body must be valid JSON.") from exc
    if not isinstance(payload, dict):
        raise HTTPException(status_code=422, detail="Request body must be a JSON object.")

    # Accept the wrapped form when an "action" key is present, else treat the
    # whole object as the action.
    raw_action = payload.get("action", payload)
    try:
        effective_action = AdaptAction.model_validate(raw_action)
    except Exception as exc:
        raise HTTPException(status_code=422, detail=f"Invalid action payload: {exc}") from exc

    observation = ENV.step(effective_action)
    return {
        "observation": observation.model_dump(),
        "reward": float(observation.reward),
        "done": bool(observation.done),
        "info": {
            "feedback": observation.feedback,
            "pass_rate": observation.pass_rate,
            "execution_status": observation.execution_status,
        },
    }
|
| 143 |
+
|
| 144 |
|
| 145 |
+
@app.get("/state")
|
| 146 |
+
def state() -> dict[str, Any]:
|
| 147 |
+
if not ENV.problem:
|
| 148 |
+
ENV.reset()
|
| 149 |
+
return ENV.state.model_dump()
|
| 150 |
|
|
|
|
|
|
|
| 151 |
|
| 152 |
+
def main(host: str | None = None, port: int | None = None) -> None:
|
| 153 |
+
if host is None or port is None:
|
| 154 |
+
parser = argparse.ArgumentParser()
|
| 155 |
+
parser.add_argument("--host", default="0.0.0.0")
|
| 156 |
+
parser.add_argument("--port", type=int, default=7860)
|
| 157 |
+
args = parser.parse_args()
|
| 158 |
+
host = args.host if host is None else host
|
| 159 |
+
port = args.port if port is None else port
|
| 160 |
uvicorn.run(app, host=host, port=port)
|
| 161 |
|
| 162 |
|
test.py
CHANGED
|
@@ -12,7 +12,7 @@ def assert_hidden_tests_are_not_exposed(payload: dict) -> None:
|
|
| 12 |
|
| 13 |
def main() -> None:
|
| 14 |
env = AdaptEnvironment()
|
| 15 |
-
observation = env.reset()
|
| 16 |
assert isinstance(observation, AdaptObservation)
|
| 17 |
assert observation.visible_tests
|
| 18 |
assert observation.problem_id == "easy_double"
|
|
|
|
| 12 |
|
| 13 |
def main() -> None:
|
| 14 |
env = AdaptEnvironment()
|
| 15 |
+
observation = env.reset(problem_id="easy_double")
|
| 16 |
assert isinstance(observation, AdaptObservation)
|
| 17 |
assert observation.visible_tests
|
| 18 |
assert observation.problem_id == "easy_double"
|
training/train_grpo.py
CHANGED
|
@@ -1,167 +1,152 @@
|
|
| 1 |
-
import
|
| 2 |
-
<<<<<<< HEAD
|
| 3 |
-
from unsloth import FastLanguageModel, PatchFastRL
|
| 4 |
-
from trl import GRPOTrainer, GRPOConfig
|
| 5 |
-
from meta_rl_dsa_solver_env import DsaEnv
|
| 6 |
-
|
| 7 |
-
# 1. Patch Unsloth for RL speedups
|
| 8 |
-
PatchFastRL("GRPO", FastLanguageModel)
|
| 9 |
-
|
| 10 |
-
# 2. Load Model & Tokenizer
|
| 11 |
-
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 12 |
-
model_name = "unsloth/Llama-3.2-3B-Instruct", # Use appropriate 2026 base
|
| 13 |
-
max_seq_length = 2048,
|
| 14 |
-
load_in_4bit = True,
|
| 15 |
-
fast_inference = True,
|
| 16 |
-
=======
|
| 17 |
-
import numpy as np
|
| 18 |
-
from unsloth import FastLanguageModel, PatchFastRL
|
| 19 |
-
from trl import GRPOTrainer, GRPOConfig
|
| 20 |
-
from meta_rl_dsa_solver_env.env.adapt_env import AdaptEnvironment
|
| 21 |
-
from meta_rl_dsa_solver_env.models import AdaptAction
|
| 22 |
-
|
| 23 |
-
# 1. Initialize Model & Speedups
|
| 24 |
-
PatchFastRL("GRPO", FastLanguageModel)
|
| 25 |
-
|
| 26 |
-
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 27 |
-
model_name = "unsloth/Llama-3.2-3B-Instruct",
|
| 28 |
-
max_seq_length = 2048,
|
| 29 |
-
load_in_4bit = True,
|
| 30 |
-
>>>>>>> environment-v2
|
| 31 |
-
)
|
| 32 |
-
|
| 33 |
-
model = FastLanguageModel.get_peft_model(
|
| 34 |
-
model,
|
| 35 |
-
r = 16,
|
| 36 |
-
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"],
|
| 37 |
-
lora_alpha = 16,
|
| 38 |
-
<<<<<<< HEAD
|
| 39 |
-
lora_dropout = 0,
|
| 40 |
-
)
|
| 41 |
-
|
| 42 |
-
# 3. Define the Reward Function (Interface for Person 2)
|
| 43 |
-
def reward_function(prompts, completions, **kwargs) -> list[float]:
|
| 44 |
-
"""
|
| 45 |
-
In GRPO, the reward function is called on the batch of completions.
|
| 46 |
-
For V0, we manually trigger the Env's step logic.
|
| 47 |
-
"""
|
| 48 |
-
env = DsaEnv()
|
| 49 |
-
rewards = []
|
| 50 |
-
|
| 51 |
-
for completion in completions:
|
| 52 |
-
# Extract code from completion (assuming markdown tags)
|
| 53 |
-
code = completion.split("```python")[-1].split("```")[0].strip() if "```" in completion else completion
|
| 54 |
-
_, reward, _, _, _ = env.step(code)
|
| 55 |
-
rewards.append(reward)
|
| 56 |
-
|
| 57 |
-
return rewards
|
| 58 |
-
|
| 59 |
-
# 4. Training Configuration
|
| 60 |
-
training_args = GRPOConfig(
|
| 61 |
-
output_dir = "./outputs",
|
| 62 |
-
learning_rate = 5e-6,
|
| 63 |
-
per_device_train_batch_size = 4,
|
| 64 |
-
gradient_accumulation_steps = 4,
|
| 65 |
-
max_prompt_length = 512,
|
| 66 |
-
max_completion_length = 512,
|
| 67 |
-
num_generations = 8, # Group size for GRPO
|
| 68 |
-
logging_steps = 1,
|
| 69 |
-
max_steps = 100, # Quick run for MVP
|
| 70 |
-
)
|
| 71 |
-
|
| 72 |
-
# 5. Initialize Trainer
|
| 73 |
-
trainer = GRPOTrainer(
|
| 74 |
-
model = model,
|
| 75 |
-
reward_funcs = [reward_function],
|
| 76 |
-
args = training_args,
|
| 77 |
-
train_dataset = [
|
| 78 |
-
{"prompt": "Write a function `sum_list(arr: list) -> int` that returns the sum of a list."}
|
| 79 |
-
] * 100, # Dummy dataset for V0 validation
|
| 80 |
-
)
|
| 81 |
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
-
|
|
|
|
| 90 |
class CurriculumManager:
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
self.window_size = 10 # Moving average window
|
| 96 |
|
| 97 |
-
def
|
| 98 |
return self.difficulties[self.current_idx]
|
| 99 |
|
| 100 |
-
def update(self,
|
| 101 |
-
self.success_history.append(
|
| 102 |
if len(self.success_history) > self.window_size:
|
| 103 |
self.success_history.pop(0)
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
if avg_success > 0.70 and self.current_idx < len(self.difficulties) - 1:
|
| 108 |
self.current_idx += 1
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
] *
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
)
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
|
| 164 |
if __name__ == "__main__":
|
| 165 |
-
|
| 166 |
-
trainer.train()
|
| 167 |
-
>>>>>>> environment-v2
|
|
|
|
| 1 |
+
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
+
import argparse
|
| 4 |
+
from dataclasses import dataclass, field
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
from env.adapt_env import AdaptEnvironment
|
| 8 |
+
from models import AdaptAction
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def extract_code(completion: str) -> str:
    """Return the code carried by a model completion.

    Args:
        completion: Raw completion text, optionally containing a Markdown
            fenced code block.

    Returns:
        The contents of the first ```` ```python ```` block if there is one,
        otherwise the contents of the first fenced block, otherwise the
        stripped completion itself. An unterminated fence yields everything
        after it.
    """
    text = completion.strip()
    if "```python" in text:
        return text.split("```python", 1)[1].split("```", 1)[0].strip()
    if "```" in text:
        fenced = text.split("```", 1)[1].split("```", 1)[0]
        info, newline, remainder = fenced.partition("\n")
        # A lone word on the opening-fence line (e.g. "py", "Python") is the
        # fence's language tag, not code; keeping it would make the submission
        # start with a stray name and fail at runtime.
        if newline and info.strip().isalnum():
            fenced = remainder
        return fenced.strip()
    return text
|
| 18 |
|
| 19 |
+
|
| 20 |
+
@dataclass
class CurriculumManager:
    """Track recent success and step through difficulty tiers in order.

    Raises:
        ValueError: On construction, if ``difficulties`` is empty or
            ``window_size`` is smaller than 1.
    """

    # Tier names in promotion order; `current_idx` indexes into this list.
    difficulties: list[str] = field(default_factory=lambda: ["easy", "medium", "hard"])
    current_idx: int = 0
    # Most recent batch success rates, oldest first, at most `window_size` long
    # when filled through `update`.
    success_history: list[float] = field(default_factory=list)
    window_size: int = 10

    def __post_init__(self) -> None:
        # Fail fast: an empty tier list makes current_difficulty() raise
        # IndexError, and a window below 1 empties the history inside update()
        # and turns the moving average into a division by zero.
        if not self.difficulties:
            raise ValueError("difficulties must contain at least one tier")
        if self.window_size < 1:
            raise ValueError("window_size must be at least 1")

    def current_difficulty(self) -> str:
        """Return the name of the tier currently being trained on."""
        return self.difficulties[self.current_idx]

    def update(self, batch_success_rate: float) -> None:
        """Record one batch's success rate and promote when the average is high.

        Args:
            batch_success_rate: Success rate for the latest batch; stored as a
                float.
        """
        self.success_history.append(float(batch_success_rate))
        if len(self.success_history) > self.window_size:
            self.success_history.pop(0)

        moving_average = sum(self.success_history) / len(self.success_history)
        if moving_average > 0.70 and self.current_idx < len(self.difficulties) - 1:
            self.current_idx += 1
            # Start the new tier with a clean window so results from the easier
            # tier cannot trigger another promotion.
            self.success_history.clear()
            print(
                f"[curriculum] promoted to {self.current_difficulty()} "
                f"(moving_success={moving_average:.2f})"
            )
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def build_reward_func(curriculum: CurriculumManager):
    """Create a reward callable bound to ``curriculum``.

    The returned function scores every completion by resetting an
    ``AdaptEnvironment`` at the curriculum's current difficulty, submitting
    the extracted code, and reading the reward from the resulting
    observation. After the batch it reports the fraction of completions
    whose ``pass_rate`` equals 1.0 back to the curriculum.
    """

    def reward_func(prompts, completions, **kwargs) -> list[float]:
        # Only the completions are scored; the other inputs are unused.
        del prompts, kwargs

        environment = AdaptEnvironment()
        # The level is fixed for the whole batch, even if it is empty.
        level = curriculum.current_difficulty()

        scores: list[float] = []
        solved = 0.0
        attempts = 0
        for candidate in completions:
            environment.reset(difficulty=level)
            action = AdaptAction(code=extract_code(candidate))
            outcome = environment.step(action)
            scores.append(float(outcome.reward))
            if outcome.pass_rate == 1.0:
                solved += 1.0
            attempts += 1

        # Skip the curriculum update entirely when nothing was scored.
        if attempts:
            curriculum.update(solved / attempts)

        return scores

    return reward_func
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def build_dataset(size: int) -> list[dict[str, str]]:
    """Build ``size`` training rows that all carry the same generic prompt.

    Args:
        size: Number of rows to create; values below 1 give an empty list.

    Returns:
        A list of independent ``{"prompt": ...}`` dictionaries.
    """
    prompt = (
        "Read the problem statement carefully. "
        "Write a Python solution that reads from stdin and prints to stdout."
    )
    # Build a fresh dict per row: list multiplication would repeat one shared
    # dict, so mutating any row would silently change every other row.
    return [{"prompt": prompt} for _ in range(size)]
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def run_training(args: argparse.Namespace) -> None:
    """Run GRPO training with the options in ``args`` and save the result.

    Loads the base model through Unsloth, attaches LoRA adapters to the
    attention projections, trains with TRL's ``GRPOTrainer`` using the
    curriculum-aware reward function, then writes the model and tokenizer to
    ``args.output_dir``.

    Args:
        args: Parsed command-line options, as produced by ``build_parser``.

    Raises:
        RuntimeError: If ``trl`` or ``unsloth`` cannot be imported.
    """
    missing_deps_message = (
        "Training dependencies are missing. Install `trl` and `unsloth` before running GRPO training."
    )

    # Unsloth is imported and the GRPO patch applied *before* anything is
    # imported from trl, so the trainer names bound below are the patched
    # ones. NOTE(review): confirm against the installed Unsloth version.
    try:
        from unsloth import FastLanguageModel, PatchFastRL
    except ImportError as exc:
        raise RuntimeError(missing_deps_message) from exc

    PatchFastRL("GRPO", FastLanguageModel)

    try:
        from trl import GRPOConfig, GRPOTrainer
    except ImportError as exc:
        raise RuntimeError(missing_deps_message) from exc

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=args.model_name,
        max_seq_length=args.max_seq_length,
        load_in_4bit=not args.disable_4bit,
    )

    model = FastLanguageModel.get_peft_model(
        model,
        r=args.lora_rank,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_alpha=args.lora_alpha,
        lora_dropout=0.0,
    )

    curriculum = CurriculumManager()
    training_args = GRPOConfig(
        output_dir=args.output_dir,
        learning_rate=args.learning_rate,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        num_generations=args.num_generations,
        max_prompt_length=args.max_prompt_length,
        max_completion_length=args.max_completion_length,
        max_steps=args.max_steps,
        logging_steps=1,
        bf16=args.bf16,
    )

    trainer = GRPOTrainer(
        model=model,
        # Use the tokenizer loaded alongside the model instead of letting the
        # trainer load a separate one.
        processing_class=tokenizer,
        reward_funcs=[build_reward_func(curriculum)],
        args=training_args,
        train_dataset=build_dataset(args.dataset_size),
    )
    trainer.train()
    model.save_pretrained(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the GRPO training script.

    Returns:
        A parser exposing model, dataset, optimisation and LoRA options, each
        with a default so the script can run without arguments.
    """
    parser = argparse.ArgumentParser(description="GRPO training entrypoint for the ADAPT DSA environment.")

    # (flag, add_argument keyword arguments), registered in display order.
    option_specs = (
        ("--model-name", {"default": "unsloth/Llama-3.2-3B-Instruct"}),
        ("--output-dir", {"default": "outputs_v2"}),
        ("--dataset-size", {"type": int, "default": 200}),
        ("--max-steps", {"type": int, "default": 250}),
        ("--batch-size", {"type": int, "default": 1}),
        ("--gradient-accumulation-steps", {"type": int, "default": 8}),
        ("--num-generations", {"type": int, "default": 8}),
        ("--max-seq-length", {"type": int, "default": 2048}),
        ("--max-prompt-length", {"type": int, "default": 512}),
        ("--max-completion-length", {"type": int, "default": 512}),
        ("--learning-rate", {"type": float, "default": 5e-6}),
        ("--lora-rank", {"type": int, "default": 16}),
        ("--lora-alpha", {"type": int, "default": 16}),
        ("--disable-4bit", {"action": "store_true"}),
        ("--bf16", {"action": "store_true"}),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    return parser
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def main(argv: list[str] | None = None) -> None:
    """Parse command-line options and start training.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv``.
    """
    run_training(build_parser().parse_args(argv))


# Start training only when the file is executed as a script.
if __name__ == "__main__":
    main()
|
|
|
|
|
|
verifier/sandbox.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
import os
|
| 2 |
import subprocess
|
|
|
|
| 3 |
import tempfile
|
| 4 |
|
| 5 |
|
|
@@ -11,7 +12,7 @@ def run_code(code: str, stdin: str, timeout: int = 2):
|
|
| 11 |
path = f.name
|
| 12 |
|
| 13 |
result = subprocess.run(
|
| 14 |
-
[
|
| 15 |
input=stdin,
|
| 16 |
text=True,
|
| 17 |
capture_output=True,
|
|
@@ -31,4 +32,4 @@ def run_code(code: str, stdin: str, timeout: int = 2):
|
|
| 31 |
|
| 32 |
finally:
|
| 33 |
if path and os.path.exists(path):
|
| 34 |
-
os.remove(path)
|
|
|
|
| 1 |
import os
|
| 2 |
import subprocess
|
| 3 |
+
import sys
|
| 4 |
import tempfile
|
| 5 |
|
| 6 |
|
|
|
|
| 12 |
path = f.name
|
| 13 |
|
| 14 |
result = subprocess.run(
|
| 15 |
+
[sys.executable, path],
|
| 16 |
input=stdin,
|
| 17 |
text=True,
|
| 18 |
capture_output=True,
|
|
|
|
| 32 |
|
| 33 |
finally:
|
| 34 |
if path and os.path.exists(path):
|
| 35 |
+
os.remove(path)
|
verifier/verifier.py
CHANGED
|
@@ -5,7 +5,13 @@ from verifier.metrics import compute_pass_rate
|
|
| 5 |
def verify(code: str, test_cases):
|
| 6 |
results = []
|
| 7 |
|
| 8 |
-
for
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
ok, output = run_code(code, stdin)
|
| 10 |
|
| 11 |
passed = ok and output.strip() == expected.strip()
|
|
@@ -23,4 +29,4 @@ def verify(code: str, test_cases):
|
|
| 23 |
return reward, {
|
| 24 |
**metrics,
|
| 25 |
"results": results,
|
| 26 |
-
}
|
|
|
|
| 5 |
def verify(code: str, test_cases):
|
| 6 |
results = []
|
| 7 |
|
| 8 |
+
for test_case in test_cases:
|
| 9 |
+
if isinstance(test_case, dict):
|
| 10 |
+
stdin = str(test_case.get("input", ""))
|
| 11 |
+
expected = str(test_case.get("output", ""))
|
| 12 |
+
else:
|
| 13 |
+
stdin, expected = test_case
|
| 14 |
+
|
| 15 |
ok, output = run_code(code, stdin)
|
| 16 |
|
| 17 |
passed = ok and output.strip() == expected.strip()
|
|
|
|
| 29 |
return reward, {
|
| 30 |
**metrics,
|
| 31 |
"results": results,
|
| 32 |
+
}
|