| |
| """gen_adaption_dataset.py — use the Adaption hosted LLM (code domain) to |
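generate a small {prompt, test, entry_point} code dataset for the spec_rl env,
then sanity-check every row (required keys, a matching def with a docstring, a
check() containing asserts, parseable Python) before writing it. spec_rl is
imported, but its reward core is not invoked by this script.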
| |
| Why this route: Adaption (api.adaptionlabs.ai) is a data-augmentation platform. |
| Its heavyweight path is upload-seed -> configure recipe -> async augmentation |
| run -> download (credits + queue + minutes). Its chat surface |
| (POST /api/v1/chat/sessions/{id}/messages, SSE) is the hosted LLM and lets us |
| synthesise problems synchronously in our exact schema with no pipeline cost. |
| That still makes the resulting dataset "built with Adaption" — answering the |
| judge's "is it just HumanEval?" with a NON-HumanEval set. |
| |
| NEVER prints the API key. Reads ADAPTION_URL / ADAPTION_API_KEY from env. |
Rows are validated structurally (schema and syntax) only — no reference
solution is executed — and nothing is written unless at least 8 rows pass, so
the eval can't silently score an all-broken dataset.
| """ |
| from __future__ import annotations |
|
|
| import json |
| import os |
| import subprocess |
| import sys |
| from pathlib import Path |
|
|
# Layout: this script lives in a directory directly under ROOT, and ROOT itself
# sits inside the repository checkout (REPO).
ROOT = Path(__file__).resolve().parents[1]
REPO = ROOT.parent

# Make the spec_rl environment importable without installing it. The path
# insert has to happen before the import below.
sys.path.insert(0, str(REPO.joinpath("environments", "spec_rl")))
import spec_rl

# Output artifacts: the validated dataset and the raw assistant reply.
OUT = ROOT.joinpath("data", "adaption_code.jsonl")
RAW = ROOT.joinpath("data", "adaption_chat_raw.txt")

# Connection settings come from the environment; an unset variable raises
# KeyError at import time. KEY is a secret — keep it out of any output.
URL = os.environ["ADAPTION_URL"].rstrip("/")
KEY = os.environ["ADAPTION_API_KEY"]
HDRS = {
    "Authorization": "Bearer " + KEY,
    "Content-Type": "application/json",
}
|
|
# The single chat instruction sent to the Adaption LLM. It pins the reply to a
# bare JSON array in the row schema that validate_row() checks. The typing
# rule exists because a prompt annotated with e.g. `List[int]` but lacking
# `from typing import List` fails with NameError once the function is defined
# (annotations are evaluated eagerly before Python 3.14). validate_row() only
# parses the prompt and cannot catch that.
PROMPT = """You are generating a SMALL code-completion dataset for an RL eval harness.
Output EXACTLY 12 problems as a JSON array (and NOTHING else — no prose, no markdown fences).
Each element MUST be an object with exactly these three string keys:

"prompt" : a complete Python function signature line `def NAME(args):` followed by a
triple-quoted docstring describing the task, ending with a trailing newline.
It MUST NOT contain the function body. Use 4-space indentation conventions.
If the signature uses names from `typing` (List, Dict, Optional, ...), start
"prompt" with the matching `from typing import ...` line before the def.
"entry_point" : the function name (matches the def in "prompt").
"test" : a Python snippet defining `def check(candidate):` whose body has >=3
`assert candidate(...) == ...` statements covering normal and edge cases.
Refer to the function ONLY as `candidate`, never by its real name.

Rules:
- Problems must be SELF-CONTAINED pure-Python (stdlib only, no imports needed beyond typing).
- Vary the domain: string manipulation, list/array logic, math, dict aggregation, simple parsing.
- These must NOT be HumanEval problems — invent fresh, original tasks.
- Make them solvable by a competent model: clear, unambiguous, deterministic.
- Each "test" must be runnable: `check(reference_solution)` passes for a correct solution.

Return ONLY the raw JSON array.
"""
|
|
|
|
def _curl(path: str, body: dict, timeout: int = 300) -> str:
    """POST *body* as JSON to ``URL + path`` via curl and return the response body.

    curl is used instead of urllib because it relies on the system CA store
    (macOS framework Python ships without one).

    Args:
        path: API path appended to URL, e.g. "/api/v1/chat/sessions".
        body: JSON-serialisable request payload.
        timeout: curl ``-m`` limit, in seconds, for the whole transfer.

    Returns:
        The response body decoded as UTF-8, with the status trailer removed.

    Raises:
        RuntimeError: curl exited non-zero, or the server answered with an
            HTTP status >= 400.
    """
    # Headers are fed to curl as a config file on stdin ("-K -") instead of
    # "-H" arguments, so the bearer token never appears in argv where other
    # local users could read it via ps. Quoted config values treat backslash
    # as an escape character, so escape backslashes and double quotes.
    config_lines = []
    for name, value in {**HDRS, "Accept": "text/event-stream"}.items():
        header = f"{name}: {value}".replace("\\", "\\\\").replace('"', '\\"')
        config_lines.append(f'header = "{header}"\n')

    # Plain `curl -sS` exits 0 on HTTP 4xx/5xx, so append the status code after
    # a marker (-w) and check it here instead of passing an error body on.
    status_marker = "\n@@HTTP_STATUS@@"
    cmd = [
        "curl", "-sS", "-m", str(timeout), "-X", "POST", URL + path,
        "-K", "-",
        "--data-binary", json.dumps(body),
        "-w", status_marker + "%{http_code}",
    ]
    p = subprocess.run(
        cmd,
        input="".join(config_lines),
        capture_output=True,
        encoding="utf-8",
        errors="replace",
    )
    if p.returncode != 0:
        raise RuntimeError(f"curl failed rc={p.returncode}: {p.stderr[:300]}")

    out, marker, status = p.stdout.rpartition(status_marker)
    if not marker:
        # -w always writes the trailer; if it is somehow absent, return as-is.
        return p.stdout
    status = status.strip()
    if status.isdigit() and int(status) >= 400:
        raise RuntimeError(f"HTTP {status} from POST {path}: {out[:300]}")
    return out
|
|
|
|
def _post(path: str, body: dict) -> dict:
    """POST *body* to *path* and decode the (non-streaming) JSON response.

    Uses a 120 s transfer limit, shorter than the streaming endpoint's.
    """
    response_text = _curl(path, body, timeout=120)
    decoded = json.loads(response_text)
    return decoded
|
|
|
|
def _post_sse(path: str, body: dict) -> str:
    """POST *body* to *path* and collapse the Server-Sent-Events reply into text.

    Each ``data:`` line is interpreted as follows:

    * empty payloads and the ``[DONE]`` sentinel are skipped;
    * JSON objects contribute their ``token`` / ``delta`` string, and a
      ``content`` string (e.g. on a ``done`` event) is remembered as the full
      reply — the last such content wins;
    * anything else (non-JSON text, or a bare JSON scalar/array such as
      ``"hi"``, ``42`` or ``null``) is treated as a plain-text token.

    Returns:
        The remembered full content if one was seen, otherwise the joined
        tokens. If that is blank, the raw response is returned unchanged so
        the caller still has something to save and inspect.
    """
    raw = _curl(path, body, timeout=300)
    chunks: list[str] = []
    full_done: str | None = None
    for line in raw.splitlines():  # splitlines() already drops "\n" / "\r\n"
        if not line.startswith("data:"):
            continue
        payload = line[len("data:"):]
        # In SSE only ONE leading space after the colon is framing; stripping
        # more would eat the leading space of plain-text tokens like " world".
        if payload.startswith(" "):
            payload = payload[1:]
        if not payload or payload.strip() == "[DONE]":
            continue
        try:
            ev = json.loads(payload)
        except json.JSONDecodeError:
            chunks.append(payload)
            continue
        if not isinstance(ev, dict):
            # A scalar/array has no .get(); it is a token, not an event object.
            chunks.append(ev if isinstance(ev, str) else payload)
            continue
        t = ev.get("type", "")
        if t in ("token", "delta") and ev.get("token") is not None:
            chunks.append(str(ev["token"]))
        elif isinstance(ev.get("token"), str):
            chunks.append(ev["token"])
        elif isinstance(ev.get("delta"), str):
            chunks.append(ev["delta"])
        elif t == "done" or "content" in ev:
            if isinstance(ev.get("content"), str):
                full_done = ev["content"]
    text = full_done if full_done else "".join(chunks)

    if not text.strip():
        text = raw
    return text
|
|
|
|
def _extract_json_array(text: str):
    """Return the JSON array embedded in an LLM reply.

    The model is told to answer with a bare array, but replies may still be
    wrapped in ```json fences or surrounded by prose. Two approaches are
    avoided here:

    * slicing from the first "[" to the last "]", which breaks when prose on
      either side contains brackets;
    * deleting fence markers, which corrupts string values containing ```.

    Instead, every "[" is tried as the start of a JSON value with
    ``json.JSONDecoder.raw_decode``, which stops at the end of that value.

    Returns:
        The first decoded list that contains at least one object (the expected
        rows). Failing that, the first decoded list of any kind.

    Raises:
        ValueError: no "[" in *text* starts a decodable JSON array.
    """
    decoder = json.JSONDecoder()
    fallback = None
    pos = text.find("[")
    while pos != -1:
        try:
            value, _end = decoder.raw_decode(text, pos)
        except json.JSONDecodeError:
            value = None
        if isinstance(value, list):
            if any(isinstance(item, dict) for item in value):
                return value
            if fallback is None:
                fallback = value
        pos = text.find("[", pos + 1)
    if fallback is None:
        raise ValueError("no JSON array found in assistant reply")
    return fallback
|
|
|
|
| |
| |
| |
| |
| |
| |
def validate_row(row: dict) -> tuple[bool, str]:
    """Statically check that one generated row is a usable completion problem.

    Nothing is executed; the prompt and test are only parsed with ``ast``.

    * ``row`` must be a dict whose "prompt", "test" and "entry_point" values
      are non-empty strings, and "entry_point" must be a valid identifier.
    * "prompt" must parse, its LAST top-level statement must be
      ``def <entry_point>``, and that function's body must be only a
      docstring. That way a completion appended after the prompt becomes the
      body, and no reference answer leaks into the prompt.
    * "test" must parse and define a top-level ``check`` containing at least
      one ``assert``.

    The prompt is parsed as-is because a docstring alone is a valid function
    body. Appending text such as " pass" after a 4-space-indented docstring
    would instead raise IndentationError for every well-formed prompt.

    Returns:
        ``(True, "ok")``, or ``(False, reason)`` for the first failed check.
    """
    import ast

    if not isinstance(row, dict):
        return False, "row is not a JSON object"
    for k in ("prompt", "test", "entry_point"):
        value = row.get(k)
        if not isinstance(value, str) or not value.strip():
            return False, f"missing/empty {k}"
    ep = row["entry_point"].strip()
    if not ep.isidentifier():
        return False, "entry_point is not a valid identifier"

    try:
        prompt_tree = ast.parse(row["prompt"])
        test_tree = ast.parse(row["test"])
    except (SyntaxError, ValueError) as e:  # ValueError: NUL bytes on older Pythons
        return False, f"syntax: {e}"

    fn = next(
        (node for node in prompt_tree.body
         if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == ep),
        None,
    )
    if fn is None:
        return False, "prompt has no matching def"
    if fn is not prompt_tree.body[-1]:
        return False, "prompt has code after the entry-point def"
    if ast.get_docstring(fn) is None:
        return False, "prompt has no docstring"
    if len(fn.body) > 1:
        return False, "prompt contains a function body"

    check = next(
        (node for node in test_tree.body
         if isinstance(node, ast.FunctionDef) and node.name == "check"),
        None,
    )
    if check is None:
        return False, "test has no check()"
    if not any(isinstance(node, ast.Assert) for node in ast.walk(check)):
        return False, "test has no asserts"
    return True, "ok"
|
|
|
|
def main(min_rows: int = 8, max_rows: int = 12) -> int:
    """Generate the dataset end to end and return a process exit code.

    Steps:

    1. Open a chat session and send PROMPT over the streaming endpoint.
    2. Save the raw reply to RAW before parsing, so it can be inspected if
       later steps fail.
    3. Extract the JSON array and keep the rows that pass validate_row().
    4. Write at most *max_rows* of them to OUT as JSON lines with sequential
       ``adaption_<n>`` task ids.

    Args:
        min_rows: minimum number of valid rows required before anything is
            written.
        max_rows: cap on the number of rows written (PROMPT asks for 12).

    Returns:
        0 on success; 2 if no session id came back; 3 if fewer than
        *min_rows* rows passed validation; 4 if no JSON array could be
        extracted from the reply. Errors raised by the HTTP helpers propagate.
    """
    sess = _post("/api/v1/chat/sessions", {"title": "spec_rl code dataset"})
    sid = None
    if isinstance(sess, dict):
        nested = sess.get("session")
        sid = sess.get("id") or (nested.get("id") if isinstance(nested, dict) else None)
    if not sid:
        keys = list(sess.keys()) if isinstance(sess, dict) else type(sess).__name__
        print("FAIL: no session id; keys=", keys, file=sys.stderr)
        return 2
    print(f"session_id={sid}")

    text = _post_sse(f"/api/v1/chat/sessions/{sid}/messages",
                     {"content": PROMPT, "input_source": "user_text"})
    RAW.parent.mkdir(parents=True, exist_ok=True)
    RAW.write_text(text, encoding="utf-8")
    print(f"assistant reply chars={len(text)} -> {RAW}")

    try:
        arr = _extract_json_array(text)
    except ValueError as e:  # includes json.JSONDecodeError
        print(f"FAIL: could not parse assistant reply ({e}); see {RAW}", file=sys.stderr)
        return 4

    good, rejected = [], []
    for i, row in enumerate(arr):
        ok, why = validate_row(row)
        if not ok:
            rejected.append((i, why))
            continue
        good.append({"prompt": row["prompt"], "test": row["test"],
                     "entry_point": row["entry_point"].strip(),
                     "task_id": f"adaption_{len(good)}"})
    print(f"well-formed rows: {len(good)} / {len(arr)}; rejected={rejected}")
    if len(good) < min_rows:
        print(f"FAIL: fewer than {min_rows} well-formed rows", file=sys.stderr)
        return 3

    rows = good[:max_rows]
    OUT.parent.mkdir(parents=True, exist_ok=True)
    with open(OUT, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(r) + "\n" for r in rows)
    print(f"WROTE {len(rows)} rows -> {OUT}")
    return 0
|
|
|
|
if __name__ == "__main__":
    # Hand main()'s integer result to the interpreter as the exit status.
    sys.exit(main())
|
|