lean-laguna / scripts /build_code_dataset.py
art87able's picture
Upload folder using huggingface_hub
8cc969e verified
#!/usr/bin/env python3
"""build_code_dataset.py — author + validate a 12-problem NON-HumanEval code
dataset in the exact spec_rl schema {prompt, test, entry_point}, then write it
to data/adaption_code.jsonl.
Validation is done with spec_rl's OWN reward core (fraction_passing): for each
problem we (a) confirm a known-correct reference solution scores 1.0, and (b)
confirm a deliberately-wrong solution scores < 1.0. This guarantees the eval
cannot silently run against an all-broken or trivially-passing dataset.
"""
from __future__ import annotations
import json, sys
from pathlib import Path
# This script sits one directory below ROOT; REPO is the directory containing ROOT.
ROOT = Path(__file__).resolve().parent.parent
REPO = ROOT.parent
# Prepend the in-repo spec_rl environment directory to sys.path so that the
# `import spec_rl` below resolves from there (it must come after this tweak).
sys.path.insert(0, str(REPO.joinpath("environments", "spec_rl")))
import spec_rl  # noqa: E402
# Destination of the validated dataset (JSON Lines, one problem per line).
OUT = ROOT.joinpath("data", "adaption_code.jsonl")
# Each entry: (prompt, test, entry_point, good_body, bad_body)
#   prompt      = function signature + docstring, with no body.
#   test        = a `check(candidate)` function with >= 3 asserts.
#   entry_point = name of the function defined by `prompt`.
#   good_body   = known-correct completion; validation requires a score of 1.0.
#   bad_body    = deliberately wrong completion; validation requires < 1.0.
# The bodies carry no `def` line, so they only make sense appended directly
# after `prompt` (presumably how spec_rl.fraction_passing assembles the
# candidate program -- verify). All embedded code therefore uses real
# 4-space indentation: prompt + body must parse as one function definition.
PROBLEMS = [
    (
        'def running_total(nums):\n'
        '    """Return a list where element i is the sum of nums[0..i] inclusive.\n'
        '    running_total([1, 2, 3]) -> [1, 3, 6]; running_total([]) -> [].\n'
        '    """\n',
        "def check(candidate):\n"
        "    assert candidate([1, 2, 3]) == [1, 3, 6]\n"
        "    assert candidate([]) == []\n"
        "    assert candidate([5]) == [5]\n"
        "    assert candidate([-1, 1, -1]) == [-1, 0, -1]\n",
        "running_total",
        "    out = []\n"
        "    s = 0\n"
        "    for n in nums:\n"
        "        s += n\n"
        "        out.append(s)\n"
        "    return out\n",
        "    return nums\n",
    ),
    (
        'def count_vowels(s):\n'
        '    """Return the number of vowels (a, e, i, o, u; case-insensitive) in s.\n'
        "    count_vowels('Hello') -> 2; count_vowels('') -> 0.\n"
        '    """\n',
        "def check(candidate):\n"
        "    assert candidate('Hello') == 2\n"
        "    assert candidate('') == 0\n"
        "    assert candidate('AEIOU') == 5\n"
        "    assert candidate('xyz') == 0\n",
        "count_vowels",
        "    return sum(1 for c in s.lower() if c in 'aeiou')\n",
        "    return len(s)\n",
    ),
    (
        'def merge_counts(a, b):\n'
        '    """Given two dicts of key->int counts, return a new dict whose value per key\n'
        '    is the sum of counts from a and b. Keys may appear in either dict.\n'
        "    merge_counts({'x': 1}, {'x': 2, 'y': 3}) -> {'x': 3, 'y': 3}.\n"
        '    """\n',
        "def check(candidate):\n"
        "    assert candidate({'x': 1}, {'x': 2, 'y': 3}) == {'x': 3, 'y': 3}\n"
        "    assert candidate({}, {}) == {}\n"
        "    assert candidate({'a': 5}, {}) == {'a': 5}\n"
        "    assert candidate({}, {'b': 7}) == {'b': 7}\n",
        "merge_counts",
        "    out = dict(a)\n"
        "    for k, v in b.items():\n"
        "        out[k] = out.get(k, 0) + v\n"
        "    return out\n",
        "    return dict(a)\n",
    ),
    (
        'def second_largest(nums):\n'
        '    """Return the second largest DISTINCT value in nums, or None if there are\n'
        '    fewer than two distinct values.\n'
        '    second_largest([4, 1, 4, 3]) -> 3; second_largest([7]) -> None.\n'
        '    """\n',
        "def check(candidate):\n"
        "    assert candidate([4, 1, 4, 3]) == 3\n"
        "    assert candidate([7]) is None\n"
        "    assert candidate([5, 5, 5]) is None\n"
        "    assert candidate([1, 2, 3, 4]) == 3\n"
        "    assert candidate([-1, -2]) == -2\n",
        "second_largest",
        "    u = sorted(set(nums), reverse=True)\n"
        "    return u[1] if len(u) >= 2 else None\n",
        "    return max(nums)\n",
    ),
    (
        'def is_palindrome(s):\n'
        '    """Return True if s is a palindrome ignoring case and non-alphanumeric chars.\n'
        "    is_palindrome('A man, a plan, a canal: Panama') -> True; is_palindrome('ab') -> False.\n"
        '    """\n',
        "def check(candidate):\n"
        "    assert candidate('A man, a plan, a canal: Panama') is True\n"
        "    assert candidate('ab') is False\n"
        "    assert candidate('') is True\n"
        "    assert candidate('Racecar') is True\n"
        "    assert candidate('No lemon, no melon') is True\n",
        "is_palindrome",
        "    t = [c.lower() for c in s if c.isalnum()]\n"
        "    return t == t[::-1]\n",
        "    return s == s[::-1]\n",
    ),
    (
        'def flatten(nested):\n'
        '    """Flatten a list that may contain nested lists (one level or deep) into a\n'
        '    single flat list, preserving order.\n'
        '    flatten([1, [2, [3, 4]], 5]) -> [1, 2, 3, 4, 5].\n'
        '    """\n',
        "def check(candidate):\n"
        "    assert candidate([1, [2, [3, 4]], 5]) == [1, 2, 3, 4, 5]\n"
        "    assert candidate([]) == []\n"
        "    assert candidate([[1], [2], [3]]) == [1, 2, 3]\n"
        "    assert candidate([1, 2, 3]) == [1, 2, 3]\n",
        "flatten",
        # The reference recurses through a module-level helper; flatten itself
        # just delegates instead of duplicating the helper's loop.
        "    return candidate_flatten(nested)\n"
        "def candidate_flatten(n):\n"
        "    out = []\n"
        "    for x in n:\n"
        "        if isinstance(x, list):\n"
        "            out.extend(candidate_flatten(x))\n"
        "        else:\n"
        "            out.append(x)\n"
        "    return out\n",
        "    return nested\n",
    ),
    (
        'def word_frequencies(text):\n'
        '    """Return a dict mapping each lowercased word to its count. Words are split on\n'
        '    whitespace; punctuation is NOT stripped beyond lowercasing.\n'
        "    word_frequencies('a a b') -> {'a': 2, 'b': 1}.\n"
        '    """\n',
        "def check(candidate):\n"
        "    assert candidate('a a b') == {'a': 2, 'b': 1}\n"
        "    assert candidate('') == {}\n"
        "    assert candidate('Hi hi HI') == {'hi': 3}\n"
        "    assert candidate('one') == {'one': 1}\n",
        "word_frequencies",
        "    d = {}\n"
        "    for w in text.lower().split():\n"
        "        d[w] = d.get(w, 0) + 1\n"
        "    return d\n",
        "    return {}\n",
    ),
    (
        'def chunk(seq, size):\n'
        '    """Split list seq into consecutive chunks of length size (the last chunk may be\n'
        '    shorter). size is a positive integer.\n'
        '    chunk([1, 2, 3, 4, 5], 2) -> [[1, 2], [3, 4], [5]].\n'
        '    """\n',
        "def check(candidate):\n"
        "    assert candidate([1, 2, 3, 4, 5], 2) == [[1, 2], [3, 4], [5]]\n"
        "    assert candidate([], 3) == []\n"
        "    assert candidate([1, 2, 3], 1) == [[1], [2], [3]]\n"
        "    assert candidate([1, 2], 5) == [[1, 2]]\n",
        "chunk",
        "    return [seq[i:i+size] for i in range(0, len(seq), size)]\n",
        "    return [seq]\n",
    ),
    (
        'def gcd(a, b):\n'
        '    """Return the greatest common divisor of two non-negative integers a and b.\n'
        '    gcd(12, 18) -> 6; gcd(7, 0) -> 7.\n'
        '    """\n',
        "def check(candidate):\n"
        "    assert candidate(12, 18) == 6\n"
        "    assert candidate(7, 0) == 7\n"
        "    assert candidate(0, 5) == 5\n"
        "    assert candidate(17, 13) == 1\n"
        "    assert candidate(100, 80) == 20\n",
        "gcd",
        "    while b:\n"
        "        a, b = b, a % b\n"
        "    return a\n",
        "    return a\n",
    ),
    (
        'def title_case(s):\n'
        '    """Return s with the first letter of each whitespace-separated word uppercased\n'
        '    and the rest lowercased.\n'
        "    title_case('hELLO wORLD') -> 'Hello World'.\n"
        '    """\n',
        "def check(candidate):\n"
        "    assert candidate('hELLO wORLD') == 'Hello World'\n"
        "    assert candidate('') == ''\n"
        "    assert candidate('a') == 'A'\n"
        "    assert candidate('the QUICK brown') == 'The Quick Brown'\n",
        "title_case",
        "    return ' '.join(w[:1].upper() + w[1:].lower() for w in s.split())\n",
        "    return s.upper()\n",
    ),
    (
        'def dedupe_preserve_order(items):\n'
        '    """Return a list with duplicates removed, keeping the FIRST occurrence order.\n'
        '    dedupe_preserve_order([3, 1, 3, 2, 1]) -> [3, 1, 2].\n'
        '    """\n',
        "def check(candidate):\n"
        "    assert candidate([3, 1, 3, 2, 1]) == [3, 1, 2]\n"
        "    assert candidate([]) == []\n"
        "    assert candidate([1, 1, 1]) == [1]\n"
        "    assert candidate(['a', 'b', 'a']) == ['a', 'b']\n",
        "dedupe_preserve_order",
        "    seen = set()\n"
        "    out = []\n"
        "    for x in items:\n"
        "        if x not in seen:\n"
        "            seen.add(x)\n"
        "            out.append(x)\n"
        "    return out\n",
        "    return list(set(items))\n",
    ),
    (
        'def roman_to_int(s):\n'
        '    """Convert a Roman numeral string (I, V, X, L, C, D, M; valid, uppercase) to an int.\n'
        "    roman_to_int('IV') -> 4; roman_to_int('XIV') -> 14; roman_to_int('MCMXCIV') -> 1994.\n"
        '    """\n',
        "def check(candidate):\n"
        "    assert candidate('IV') == 4\n"
        "    assert candidate('XIV') == 14\n"
        "    assert candidate('MCMXCIV') == 1994\n"
        "    assert candidate('III') == 3\n"
        "    assert candidate('LVIII') == 58\n",
        "roman_to_int",
        "    vals = {'I':1,'V':5,'X':10,'L':50,'C':100,'D':500,'M':1000}\n"
        "    total = 0\n"
        "    prev = 0\n"
        "    for c in reversed(s):\n"
        "        v = vals[c]\n"
        "        if v < prev:\n"
        "            total -= v\n"
        "        else:\n"
        "            total += v\n"
        "        prev = v\n"
        "    return total\n",
        "    return sum({'I':1,'V':5,'X':10,'L':50,'C':100,'D':500,'M':1000}[c] for c in s)\n",
    ),
]
def main() -> int:
    """Validate every PROBLEMS entry with spec_rl, then write the dataset.

    For each problem, the reference (good) body must score exactly 1.0 and the
    deliberately wrong (bad) body must score below 1.0 under
    ``spec_rl.fraction_passing``. A status line is printed per problem.

    Returns:
        0 if every problem validated and OUT was written;
        1 if any problem failed (failures go to stderr, nothing is written).
    """
    validated = []
    bad_entries = []
    for idx, (prompt, test, entry_point, good_body, bad_body) in enumerate(PROBLEMS):
        # spec_rl schema: prompt / test / entry_point (key order is preserved
        # into the JSONL output).
        record = dict(prompt=prompt, test=test, entry_point=entry_point)
        score_good = spec_rl.fraction_passing(record, good_body)
        score_bad = spec_rl.fraction_passing(record, bad_body)
        passed = score_good == 1.0 and score_bad < 1.0
        print(f"[{'OK' if passed else 'BAD'}] {entry_point:24} good={score_good:.3f} bad={score_bad:.3f}")
        if not passed:
            bad_entries.append((entry_point, score_good, score_bad))
        validated.append(dict(record, task_id=f"adaption_{idx}"))

    # Refuse to write a partially-validated dataset.
    if bad_entries:
        print("VALIDATION FAILURES:", bad_entries, file=sys.stderr)
        return 1

    OUT.parent.mkdir(parents=True, exist_ok=True)
    with open(OUT, "w") as handle:
        handle.writelines(json.dumps(row) + "\n" for row in validated)
    print(f"\nWROTE {len(validated)} validated rows -> {OUT}")
    return 0
if __name__ == "__main__":
    # Use main()'s return value (0 = success, 1 = validation failure) as the
    # process exit status.
    sys.exit(main())