#!/usr/bin/env python3 """build_code_dataset.py — author + validate a 12-problem NON-HumanEval code dataset in the exact spec_rl schema {prompt, test, entry_point}, then write it to data/adaption_code.jsonl. Validation is done with spec_rl's OWN reward core (fraction_passing): for each problem we (a) confirm a known-correct reference solution scores 1.0, and (b) confirm a deliberately-wrong solution scores < 1.0. This guarantees the eval cannot silently run against an all-broken or trivially-passing dataset. """ from __future__ import annotations import json, sys from pathlib import Path ROOT = Path(__file__).resolve().parents[1] REPO = ROOT.parent sys.path.insert(0, str(REPO / "environments" / "spec_rl")) import spec_rl OUT = ROOT / "data" / "adaption_code.jsonl" # Each entry: (prompt, test, entry_point, good_body, bad_body) # prompt = signature + docstring (no body). test = check() with >=3 asserts. PROBLEMS = [ ( 'def running_total(nums):\n """Return a list where element i is the sum of nums[0..i] inclusive.\n running_total([1, 2, 3]) -> [1, 3, 6]; running_total([]) -> [].\n """\n', "def check(candidate):\n assert candidate([1, 2, 3]) == [1, 3, 6]\n assert candidate([]) == []\n assert candidate([5]) == [5]\n assert candidate([-1, 1, -1]) == [-1, 0, -1]\n", "running_total", " out = []\n s = 0\n for n in nums:\n s += n\n out.append(s)\n return out\n", " return nums\n", ), ( 'def count_vowels(s):\n """Return the number of vowels (a, e, i, o, u; case-insensitive) in s.\n count_vowels(\'Hello\') -> 2; count_vowels(\'\') -> 0.\n """\n', "def check(candidate):\n assert candidate('Hello') == 2\n assert candidate('') == 0\n assert candidate('AEIOU') == 5\n assert candidate('xyz') == 0\n", "count_vowels", " return sum(1 for c in s.lower() if c in 'aeiou')\n", " return len(s)\n", ), ( 'def merge_counts(a, b):\n """Given two dicts of key->int counts, return a new dict whose value per key\n is the sum of counts from a and b. Keys may appear in either dict.\n merge_counts({\'x\': 1}, {\'x\': 2, \'y\': 3}) -> {\'x\': 3, \'y\': 3}.\n """\n', "def check(candidate):\n assert candidate({'x': 1}, {'x': 2, 'y': 3}) == {'x': 3, 'y': 3}\n assert candidate({}, {}) == {}\n assert candidate({'a': 5}, {}) == {'a': 5}\n assert candidate({}, {'b': 7}) == {'b': 7}\n", "merge_counts", " out = dict(a)\n for k, v in b.items():\n out[k] = out.get(k, 0) + v\n return out\n", " return dict(a)\n", ), ( 'def second_largest(nums):\n """Return the second largest DISTINCT value in nums, or None if there are\n fewer than two distinct values.\n second_largest([4, 1, 4, 3]) -> 3; second_largest([7]) -> None.\n """\n', "def check(candidate):\n assert candidate([4, 1, 4, 3]) == 3\n assert candidate([7]) is None\n assert candidate([5, 5, 5]) is None\n assert candidate([1, 2, 3, 4]) == 3\n assert candidate([-1, -2]) == -2\n", "second_largest", " u = sorted(set(nums), reverse=True)\n return u[1] if len(u) >= 2 else None\n", " return max(nums)\n", ), ( 'def is_palindrome(s):\n """Return True if s is a palindrome ignoring case and non-alphanumeric chars.\n is_palindrome(\'A man, a plan, a canal: Panama\') -> True; is_palindrome(\'ab\') -> False.\n """\n', "def check(candidate):\n assert candidate('A man, a plan, a canal: Panama') is True\n assert candidate('ab') is False\n assert candidate('') is True\n assert candidate('Racecar') is True\n assert candidate('No lemon, no melon') is True\n", "is_palindrome", " t = [c.lower() for c in s if c.isalnum()]\n return t == t[::-1]\n", " return s == s[::-1]\n", ), ( 'def flatten(nested):\n """Flatten a list that may contain nested lists (one level or deep) into a\n single flat list, preserving order.\n flatten([1, [2, [3, 4]], 5]) -> [1, 2, 3, 4, 5].\n """\n', "def check(candidate):\n assert candidate([1, [2, [3, 4]], 5]) == [1, 2, 3, 4, 5]\n assert candidate([]) == []\n assert candidate([[1], [2], [3]]) == [1, 2, 3]\n assert candidate([1, 2, 3]) == [1, 2, 3]\n", "flatten", " out = []\n for x in nested:\n if isinstance(x, list):\n out.extend(candidate_flatten(x))\n else:\n out.append(x)\n return out\ndef candidate_flatten(n):\n out = []\n for x in n:\n if isinstance(x, list):\n out.extend(candidate_flatten(x))\n else:\n out.append(x)\n return out\n", " return nested\n", ), ( 'def word_frequencies(text):\n """Return a dict mapping each lowercased word to its count. Words are split on\n whitespace; punctuation is NOT stripped beyond lowercasing.\n word_frequencies(\'a a b\') -> {\'a\': 2, \'b\': 1}.\n """\n', "def check(candidate):\n assert candidate('a a b') == {'a': 2, 'b': 1}\n assert candidate('') == {}\n assert candidate('Hi hi HI') == {'hi': 3}\n assert candidate('one') == {'one': 1}\n", "word_frequencies", " d = {}\n for w in text.lower().split():\n d[w] = d.get(w, 0) + 1\n return d\n", " return {}\n", ), ( 'def chunk(seq, size):\n """Split list seq into consecutive chunks of length size (the last chunk may be\n shorter). size is a positive integer.\n chunk([1, 2, 3, 4, 5], 2) -> [[1, 2], [3, 4], [5]].\n """\n', "def check(candidate):\n assert candidate([1, 2, 3, 4, 5], 2) == [[1, 2], [3, 4], [5]]\n assert candidate([], 3) == []\n assert candidate([1, 2, 3], 1) == [[1], [2], [3]]\n assert candidate([1, 2], 5) == [[1, 2]]\n", "chunk", " return [seq[i:i+size] for i in range(0, len(seq), size)]\n", " return [seq]\n", ), ( 'def gcd(a, b):\n """Return the greatest common divisor of two non-negative integers a and b.\n gcd(12, 18) -> 6; gcd(7, 0) -> 7.\n """\n', "def check(candidate):\n assert candidate(12, 18) == 6\n assert candidate(7, 0) == 7\n assert candidate(0, 5) == 5\n assert candidate(17, 13) == 1\n assert candidate(100, 80) == 20\n", "gcd", " while b:\n a, b = b, a % b\n return a\n", " return a\n", ), ( 'def title_case(s):\n """Return s with the first letter of each whitespace-separated word uppercased\n and the rest lowercased.\n title_case(\'hELLO wORLD\') -> \'Hello World\'.\n """\n', "def check(candidate):\n assert candidate('hELLO wORLD') == 'Hello World'\n assert candidate('') == ''\n assert candidate('a') == 'A'\n assert candidate('the QUICK brown') == 'The Quick Brown'\n", "title_case", " return ' '.join(w[:1].upper() + w[1:].lower() for w in s.split())\n", " return s.upper()\n", ), ( 'def dedupe_preserve_order(items):\n """Return a list with duplicates removed, keeping the FIRST occurrence order.\n dedupe_preserve_order([3, 1, 3, 2, 1]) -> [3, 1, 2].\n """\n', "def check(candidate):\n assert candidate([3, 1, 3, 2, 1]) == [3, 1, 2]\n assert candidate([]) == []\n assert candidate([1, 1, 1]) == [1]\n assert candidate(['a', 'b', 'a']) == ['a', 'b']\n", "dedupe_preserve_order", " seen = set()\n out = []\n for x in items:\n if x not in seen:\n seen.add(x)\n out.append(x)\n return out\n", " return list(set(items))\n", ), ( 'def roman_to_int(s):\n """Convert a Roman numeral string (I, V, X, L, C, D, M; valid, uppercase) to an int.\n roman_to_int(\'IV\') -> 4; roman_to_int(\'XIV\') -> 14; roman_to_int(\'MCMXCIV\') -> 1994.\n """\n', "def check(candidate):\n assert candidate('IV') == 4\n assert candidate('XIV') == 14\n assert candidate('MCMXCIV') == 1994\n assert candidate('III') == 3\n assert candidate('LVIII') == 58\n", "roman_to_int", " vals = {'I':1,'V':5,'X':10,'L':50,'C':100,'D':500,'M':1000}\n total = 0\n prev = 0\n for c in reversed(s):\n v = vals[c]\n if v < prev:\n total -= v\n else:\n total += v\n prev = v\n return total\n", " return sum({'I':1,'V':5,'X':10,'L':50,'C':100,'D':500,'M':1000}[c] for c in s)\n", ), ] def main() -> int: rows = [] failures = [] for i, (prompt, test, ep, good, bad) in enumerate(PROBLEMS): prob = {"prompt": prompt, "test": test, "entry_point": ep} good_score = spec_rl.fraction_passing(prob, good) bad_score = spec_rl.fraction_passing(prob, bad) ok = (good_score == 1.0) and (bad_score < 1.0) status = "OK" if ok else "BAD" print(f"[{status}] {ep:24} good={good_score:.3f} bad={bad_score:.3f}") if not ok: failures.append((ep, good_score, bad_score)) rows.append({**prob, "task_id": f"adaption_{i}"}) if failures: print("VALIDATION FAILURES:", failures, file=sys.stderr) return 1 OUT.parent.mkdir(parents=True, exist_ok=True) with open(OUT, "w") as f: for r in rows: f.write(json.dumps(r) + "\n") print(f"\nWROTE {len(rows)} validated rows -> {OUT}") return 0 if __name__ == "__main__": raise SystemExit(main())