File size: 9,637 Bytes
8cc969e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 | #!/usr/bin/env python3
"""build_code_dataset.py — author + validate a 12-problem NON-HumanEval code
dataset in the exact spec_rl schema {prompt, test, entry_point} (each row also
carries a task_id), then write it to data/adaption_code.jsonl.

Validation is done with spec_rl's OWN reward core (fraction_passing): for each
problem we (a) confirm a known-correct reference solution scores 1.0, and (b)
confirm a deliberately-wrong solution scores < 1.0. If any problem fails either
check, nothing is written and the script exits with status 1. This guarantees
the eval cannot silently run against an all-broken or trivially-passing dataset.
"""
from __future__ import annotations
import json, sys
from pathlib import Path
# Paths are derived from this file's location: ROOT is the parent of the
# directory that contains this script, and REPO is ROOT's parent.
ROOT = Path(__file__).resolve().parents[1]
REPO = ROOT.parent
# Put the spec_rl environment directory first on sys.path so `import spec_rl`
# below finds it there; this line must run before that import.
sys.path.insert(0, str(REPO / "environments" / "spec_rl"))
import spec_rl  # noqa: E402  (deliberately placed after the sys.path tweak)
# Destination of the validated dataset (main() writes one JSON object per line).
OUT = ROOT / "data" / "adaption_code.jsonl"
# Each entry: (prompt, test, entry_point, good_body, bad_body)
#   prompt     signature + docstring (no body), ending with a newline.
#   test       "def check(candidate):" followed by >= 3 asserts.
#   good_body  known-correct body, indented 4 spaces so that prompt + body is a
#              complete, compilable function definition (nested blocks use
#              8/12 spaces).
#   bad_body   deliberately wrong body; must score < 1.0 under
#              spec_rl.fraction_passing.
# Multi-line strings are written as one literal per source line (implicit
# concatenation) so every line's indentation is explicit and easy to audit.
PROBLEMS = [
    (
        'def running_total(nums):\n'
        '    """Return a list where element i is the sum of nums[0..i] inclusive.\n'
        '    running_total([1, 2, 3]) -> [1, 3, 6]; running_total([]) -> [].\n'
        '    """\n',
        "def check(candidate):\n"
        "    assert candidate([1, 2, 3]) == [1, 3, 6]\n"
        "    assert candidate([]) == []\n"
        "    assert candidate([5]) == [5]\n"
        "    assert candidate([-1, 1, -1]) == [-1, 0, -1]\n",
        "running_total",
        "    out = []\n"
        "    s = 0\n"
        "    for n in nums:\n"
        "        s += n\n"
        "        out.append(s)\n"
        "    return out\n",
        "    return nums\n",
    ),
    (
        'def count_vowels(s):\n'
        '    """Return the number of vowels (a, e, i, o, u; case-insensitive) in s.\n'
        "    count_vowels('Hello') -> 2; count_vowels('') -> 0.\n"
        '    """\n',
        "def check(candidate):\n"
        "    assert candidate('Hello') == 2\n"
        "    assert candidate('') == 0\n"
        "    assert candidate('AEIOU') == 5\n"
        "    assert candidate('xyz') == 0\n",
        "count_vowels",
        "    return sum(1 for c in s.lower() if c in 'aeiou')\n",
        "    return len(s)\n",
    ),
    (
        'def merge_counts(a, b):\n'
        '    """Given two dicts of key->int counts, return a new dict whose value per key\n'
        '    is the sum of counts from a and b. Keys may appear in either dict.\n'
        "    merge_counts({'x': 1}, {'x': 2, 'y': 3}) -> {'x': 3, 'y': 3}.\n"
        '    """\n',
        "def check(candidate):\n"
        "    assert candidate({'x': 1}, {'x': 2, 'y': 3}) == {'x': 3, 'y': 3}\n"
        "    assert candidate({}, {}) == {}\n"
        "    assert candidate({'a': 5}, {}) == {'a': 5}\n"
        "    assert candidate({}, {'b': 7}) == {'b': 7}\n",
        "merge_counts",
        "    out = dict(a)\n"
        "    for k, v in b.items():\n"
        "        out[k] = out.get(k, 0) + v\n"
        "    return out\n",
        "    return dict(a)\n",
    ),
    (
        'def second_largest(nums):\n'
        '    """Return the second largest DISTINCT value in nums, or None if there are\n'
        '    fewer than two distinct values.\n'
        '    second_largest([4, 1, 4, 3]) -> 3; second_largest([7]) -> None.\n'
        '    """\n',
        "def check(candidate):\n"
        "    assert candidate([4, 1, 4, 3]) == 3\n"
        "    assert candidate([7]) is None\n"
        "    assert candidate([5, 5, 5]) is None\n"
        "    assert candidate([1, 2, 3, 4]) == 3\n"
        "    assert candidate([-1, -2]) == -2\n",
        "second_largest",
        "    u = sorted(set(nums), reverse=True)\n"
        "    return u[1] if len(u) >= 2 else None\n",
        "    return max(nums)\n",
    ),
    (
        'def is_palindrome(s):\n'
        '    """Return True if s is a palindrome ignoring case and non-alphanumeric chars.\n'
        "    is_palindrome('A man, a plan, a canal: Panama') -> True; is_palindrome('ab') -> False.\n"
        '    """\n',
        "def check(candidate):\n"
        "    assert candidate('A man, a plan, a canal: Panama') is True\n"
        "    assert candidate('ab') is False\n"
        "    assert candidate('') is True\n"
        "    assert candidate('Racecar') is True\n"
        "    assert candidate('No lemon, no melon') is True\n",
        "is_palindrome",
        "    t = [c.lower() for c in s if c.isalnum()]\n"
        "    return t == t[::-1]\n",
        "    return s == s[::-1]\n",
    ),
    (
        'def flatten(nested):\n'
        '    """Flatten a list that may contain nested lists (one level or deep) into a\n'
        '    single flat list, preserving order.\n'
        '    flatten([1, [2, [3, 4]], 5]) -> [1, 2, 3, 4, 5].\n'
        '    """\n',
        "def check(candidate):\n"
        "    assert candidate([1, [2, [3, 4]], 5]) == [1, 2, 3, 4, 5]\n"
        "    assert candidate([]) == []\n"
        "    assert candidate([[1], [2], [3]]) == [1, 2, 3]\n"
        "    assert candidate([1, 2, 3]) == [1, 2, 3]\n",
        "flatten",
        # Self-contained reference: recursion goes through a nested helper
        # instead of a second top-level def appended after the body.
        "    out = []\n"
        "    def walk(items):\n"
        "        for x in items:\n"
        "            if isinstance(x, list):\n"
        "                walk(x)\n"
        "            else:\n"
        "                out.append(x)\n"
        "    walk(nested)\n"
        "    return out\n",
        "    return nested\n",
    ),
    (
        'def word_frequencies(text):\n'
        '    """Return a dict mapping each lowercased word to its count. Words are split on\n'
        '    whitespace; punctuation is NOT stripped beyond lowercasing.\n'
        "    word_frequencies('a a b') -> {'a': 2, 'b': 1}.\n"
        '    """\n',
        "def check(candidate):\n"
        "    assert candidate('a a b') == {'a': 2, 'b': 1}\n"
        "    assert candidate('') == {}\n"
        "    assert candidate('Hi hi HI') == {'hi': 3}\n"
        "    assert candidate('one') == {'one': 1}\n",
        "word_frequencies",
        "    d = {}\n"
        "    for w in text.lower().split():\n"
        "        d[w] = d.get(w, 0) + 1\n"
        "    return d\n",
        "    return {}\n",
    ),
    (
        'def chunk(seq, size):\n'
        '    """Split list seq into consecutive chunks of length size (the last chunk may be\n'
        '    shorter). size is a positive integer.\n'
        '    chunk([1, 2, 3, 4, 5], 2) -> [[1, 2], [3, 4], [5]].\n'
        '    """\n',
        "def check(candidate):\n"
        "    assert candidate([1, 2, 3, 4, 5], 2) == [[1, 2], [3, 4], [5]]\n"
        "    assert candidate([], 3) == []\n"
        "    assert candidate([1, 2, 3], 1) == [[1], [2], [3]]\n"
        "    assert candidate([1, 2], 5) == [[1, 2]]\n",
        "chunk",
        "    return [seq[i:i+size] for i in range(0, len(seq), size)]\n",
        "    return [seq]\n",
    ),
    (
        'def gcd(a, b):\n'
        '    """Return the greatest common divisor of two non-negative integers a and b.\n'
        '    gcd(12, 18) -> 6; gcd(7, 0) -> 7.\n'
        '    """\n',
        "def check(candidate):\n"
        "    assert candidate(12, 18) == 6\n"
        "    assert candidate(7, 0) == 7\n"
        "    assert candidate(0, 5) == 5\n"
        "    assert candidate(17, 13) == 1\n"
        "    assert candidate(100, 80) == 20\n",
        "gcd",
        "    while b:\n"
        "        a, b = b, a % b\n"
        "    return a\n",
        "    return a\n",
    ),
    (
        'def title_case(s):\n'
        '    """Return s with the first letter of each whitespace-separated word uppercased\n'
        '    and the rest lowercased.\n'
        "    title_case('hELLO wORLD') -> 'Hello World'.\n"
        '    """\n',
        "def check(candidate):\n"
        "    assert candidate('hELLO wORLD') == 'Hello World'\n"
        "    assert candidate('') == ''\n"
        "    assert candidate('a') == 'A'\n"
        "    assert candidate('the QUICK brown') == 'The Quick Brown'\n",
        "title_case",
        "    return ' '.join(w[:1].upper() + w[1:].lower() for w in s.split())\n",
        "    return s.upper()\n",
    ),
    (
        'def dedupe_preserve_order(items):\n'
        '    """Return a list with duplicates removed, keeping the FIRST occurrence order.\n'
        '    dedupe_preserve_order([3, 1, 3, 2, 1]) -> [3, 1, 2].\n'
        '    """\n',
        "def check(candidate):\n"
        "    assert candidate([3, 1, 3, 2, 1]) == [3, 1, 2]\n"
        "    assert candidate([]) == []\n"
        "    assert candidate([1, 1, 1]) == [1]\n"
        "    assert candidate(['a', 'b', 'a']) == ['a', 'b']\n",
        "dedupe_preserve_order",
        "    seen = set()\n"
        "    out = []\n"
        "    for x in items:\n"
        "        if x not in seen:\n"
        "            seen.add(x)\n"
        "            out.append(x)\n"
        "    return out\n",
        "    return list(set(items))\n",
    ),
    (
        'def roman_to_int(s):\n'
        '    """Convert a Roman numeral string (I, V, X, L, C, D, M; valid, uppercase) to an int.\n'
        "    roman_to_int('IV') -> 4; roman_to_int('XIV') -> 14; roman_to_int('MCMXCIV') -> 1994.\n"
        '    """\n',
        "def check(candidate):\n"
        "    assert candidate('IV') == 4\n"
        "    assert candidate('XIV') == 14\n"
        "    assert candidate('MCMXCIV') == 1994\n"
        "    assert candidate('III') == 3\n"
        "    assert candidate('LVIII') == 58\n",
        "roman_to_int",
        "    vals = {'I':1,'V':5,'X':10,'L':50,'C':100,'D':500,'M':1000}\n"
        "    total = 0\n"
        "    prev = 0\n"
        "    for c in reversed(s):\n"
        "        v = vals[c]\n"
        "        if v < prev:\n"
        "            total -= v\n"
        "        else:\n"
        "            total += v\n"
        "        prev = v\n"
        "    return total\n",
        "    return sum({'I':1,'V':5,'X':10,'L':50,'C':100,'D':500,'M':1000}[c] for c in s)\n",
    ),
]
def main() -> int:
    """Validate every problem in PROBLEMS, then write the dataset as JSON Lines.

    For each problem, the reference (good) body must score exactly 1.0 and the
    deliberately-wrong (bad) body must score below 1.0 under
    ``spec_rl.fraction_passing``. A status line is printed per problem.

    Returns:
        0 if every problem validated and OUT was written; 1 otherwise, in which
        case OUT is left untouched.
    """
    rows = []
    failures = []
    for i, (prompt, test, ep, good, bad) in enumerate(PROBLEMS):
        prob = {"prompt": prompt, "test": test, "entry_point": ep}
        good_score = spec_rl.fraction_passing(prob, good)
        bad_score = spec_rl.fraction_passing(prob, bad)
        ok = good_score == 1.0 and bad_score < 1.0
        status = "OK" if ok else "BAD"
        print(f"[{status}] {ep:24} good={good_score:.3f} bad={bad_score:.3f}")
        if not ok:
            failures.append((ep, good_score, bad_score))
        rows.append({**prob, "task_id": f"adaption_{i}"})
    if not rows:
        # An empty dataset would pass validation vacuously; refuse to write it.
        print("VALIDATION FAILURES: no problems defined", file=sys.stderr)
        return 1
    if failures:
        print("VALIDATION FAILURES:", failures, file=sys.stderr)
        return 1
    OUT.parent.mkdir(parents=True, exist_ok=True)
    # Write to a sibling temp file and rename it over OUT only once it is
    # complete, so an interrupted run never leaves a half-written dataset at
    # OUT for the eval to pick up.
    tmp = OUT.with_name(OUT.name + ".tmp")
    with tmp.open("w", encoding="utf-8") as f:
        for r in rows:
            f.write(json.dumps(r) + "\n")
    tmp.replace(OUT)
    print(f"\nWROTE {len(rows)} validated rows -> {OUT}")
    return 0
if __name__ == "__main__":
    # Hand main()'s 0/1 result to the interpreter as the process exit status.
    sys.exit(main())
|