File size: 9,637 Bytes
8cc969e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
#!/usr/bin/env python3
"""build_code_dataset.py — author and validate a 12-problem, non-HumanEval code
dataset in the spec_rl schema {prompt, test, entry_point} (each row additionally
carries a task_id), then write it to data/adaption_code.jsonl.

Validation is done with spec_rl's OWN reward core (fraction_passing): for each
problem we (a) confirm a known-correct reference solution scores 1.0, and (b)
confirm a deliberately-wrong solution scores < 1.0. This guarantees the eval
cannot silently run against an all-broken or trivially-passing dataset.
"""
from __future__ import annotations
import json, sys
from pathlib import Path

# The script's own directory sits directly inside ROOT (parents[1] of this
# file); REPO is the directory that contains ROOT.
ROOT = Path(__file__).resolve().parents[1]
REPO = ROOT.parent
# Prepend REPO/environments/spec_rl to sys.path so the import below resolves to
# the in-repo spec_rl. This must execute before `import spec_rl`.
sys.path.insert(0, str(REPO / "environments" / "spec_rl"))
import spec_rl  # noqa: E402  (depends on the sys.path tweak above)

# Destination of the validated dataset, written by main() as JSON Lines.
OUT = ROOT / "data" / "adaption_code.jsonl"

# Each entry: (prompt, test, entry_point, good_body, bad_body)
# prompt = signature + docstring (no body). test = check() with >=3 asserts.
# good_body / bad_body are 4-space-indented function bodies meant to follow the
# prompt's signature (presumably concatenated to it by spec_rl.fraction_passing
# — confirm there). Contract:
#   * good_body must pass every assert in test;
#   * bad_body must fail at least one assert, DETERMINISTICALLY — never depend on
#     set/dict iteration order of str values, which changes with PYTHONHASHSEED;
#   * neither body defines extra top-level helpers (a stray top-level `def` is
#     redundant and can be cut off by "\ndef"-style stop-sequence truncation).
PROBLEMS = [
    (
        'def running_total(nums):\n    """Return a list where element i is the sum of nums[0..i] inclusive.\n    running_total([1, 2, 3]) -> [1, 3, 6]; running_total([]) -> [].\n    """\n',
        "def check(candidate):\n    assert candidate([1, 2, 3]) == [1, 3, 6]\n    assert candidate([]) == []\n    assert candidate([5]) == [5]\n    assert candidate([-1, 1, -1]) == [-1, 0, -1]\n",
        "running_total",
        "    out = []\n    s = 0\n    for n in nums:\n        s += n\n        out.append(s)\n    return out\n",
        "    return nums\n",
    ),
    (
        'def count_vowels(s):\n    """Return the number of vowels (a, e, i, o, u; case-insensitive) in s.\n    count_vowels(\'Hello\') -> 2; count_vowels(\'\') -> 0.\n    """\n',
        "def check(candidate):\n    assert candidate('Hello') == 2\n    assert candidate('') == 0\n    assert candidate('AEIOU') == 5\n    assert candidate('xyz') == 0\n",
        "count_vowels",
        "    return sum(1 for c in s.lower() if c in 'aeiou')\n",
        "    return len(s)\n",
    ),
    (
        'def merge_counts(a, b):\n    """Given two dicts of key->int counts, return a new dict whose value per key\n    is the sum of counts from a and b. Keys may appear in either dict.\n    merge_counts({\'x\': 1}, {\'x\': 2, \'y\': 3}) -> {\'x\': 3, \'y\': 3}.\n    """\n',
        "def check(candidate):\n    assert candidate({'x': 1}, {'x': 2, 'y': 3}) == {'x': 3, 'y': 3}\n    assert candidate({}, {}) == {}\n    assert candidate({'a': 5}, {}) == {'a': 5}\n    assert candidate({}, {'b': 7}) == {'b': 7}\n",
        "merge_counts",
        "    out = dict(a)\n    for k, v in b.items():\n        out[k] = out.get(k, 0) + v\n    return out\n",
        "    return dict(a)\n",
    ),
    (
        'def second_largest(nums):\n    """Return the second largest DISTINCT value in nums, or None if there are\n    fewer than two distinct values.\n    second_largest([4, 1, 4, 3]) -> 3; second_largest([7]) -> None.\n    """\n',
        "def check(candidate):\n    assert candidate([4, 1, 4, 3]) == 3\n    assert candidate([7]) is None\n    assert candidate([5, 5, 5]) is None\n    assert candidate([1, 2, 3, 4]) == 3\n    assert candidate([-1, -2]) == -2\n",
        "second_largest",
        "    u = sorted(set(nums), reverse=True)\n    return u[1] if len(u) >= 2 else None\n",
        "    return max(nums)\n",
    ),
    (
        'def is_palindrome(s):\n    """Return True if s is a palindrome ignoring case and non-alphanumeric chars.\n    is_palindrome(\'A man, a plan, a canal: Panama\') -> True; is_palindrome(\'ab\') -> False.\n    """\n',
        "def check(candidate):\n    assert candidate('A man, a plan, a canal: Panama') is True\n    assert candidate('ab') is False\n    assert candidate('') is True\n    assert candidate('Racecar') is True\n    assert candidate('No lemon, no melon') is True\n",
        "is_palindrome",
        "    t = [c.lower() for c in s if c.isalnum()]\n    return t == t[::-1]\n",
        "    return s == s[::-1]\n",
    ),
    (
        'def flatten(nested):\n    """Flatten a list that may contain nested lists (one level or deep) into a\n    single flat list, preserving order.\n    flatten([1, [2, [3, 4]], 5]) -> [1, 2, 3, 4, 5].\n    """\n',
        "def check(candidate):\n    assert candidate([1, [2, [3, 4]], 5]) == [1, 2, 3, 4, 5]\n    assert candidate([]) == []\n    assert candidate([[1], [2], [3]]) == [1, 2, 3]\n    assert candidate([1, 2, 3]) == [1, 2, 3]\n",
        "flatten",
        # Recurse through the entry point itself rather than a duplicated
        # module-level helper, so the body is self-contained.
        "    out = []\n    for x in nested:\n        if isinstance(x, list):\n            out.extend(flatten(x))\n        else:\n            out.append(x)\n    return out\n",
        "    return nested\n",
    ),
    (
        'def word_frequencies(text):\n    """Return a dict mapping each lowercased word to its count. Words are split on\n    whitespace; punctuation is NOT stripped beyond lowercasing.\n    word_frequencies(\'a a b\') -> {\'a\': 2, \'b\': 1}.\n    """\n',
        "def check(candidate):\n    assert candidate('a a b') == {'a': 2, 'b': 1}\n    assert candidate('') == {}\n    assert candidate('Hi hi HI') == {'hi': 3}\n    assert candidate('one') == {'one': 1}\n",
        "word_frequencies",
        "    d = {}\n    for w in text.lower().split():\n        d[w] = d.get(w, 0) + 1\n    return d\n",
        "    return {}\n",
    ),
    (
        'def chunk(seq, size):\n    """Split list seq into consecutive chunks of length size (the last chunk may be\n    shorter). size is a positive integer.\n    chunk([1, 2, 3, 4, 5], 2) -> [[1, 2], [3, 4], [5]].\n    """\n',
        "def check(candidate):\n    assert candidate([1, 2, 3, 4, 5], 2) == [[1, 2], [3, 4], [5]]\n    assert candidate([], 3) == []\n    assert candidate([1, 2, 3], 1) == [[1], [2], [3]]\n    assert candidate([1, 2], 5) == [[1, 2]]\n",
        "chunk",
        "    return [seq[i:i+size] for i in range(0, len(seq), size)]\n",
        "    return [seq]\n",
    ),
    (
        'def gcd(a, b):\n    """Return the greatest common divisor of two non-negative integers a and b.\n    gcd(12, 18) -> 6; gcd(7, 0) -> 7.\n    """\n',
        "def check(candidate):\n    assert candidate(12, 18) == 6\n    assert candidate(7, 0) == 7\n    assert candidate(0, 5) == 5\n    assert candidate(17, 13) == 1\n    assert candidate(100, 80) == 20\n",
        "gcd",
        "    while b:\n        a, b = b, a % b\n    return a\n",
        "    return a\n",
    ),
    (
        'def title_case(s):\n    """Return s with the first letter of each whitespace-separated word uppercased\n    and the rest lowercased.\n    title_case(\'hELLO wORLD\') -> \'Hello World\'.\n    """\n',
        "def check(candidate):\n    assert candidate('hELLO wORLD') == 'Hello World'\n    assert candidate('') == ''\n    assert candidate('a') == 'A'\n    assert candidate('the QUICK brown') == 'The Quick Brown'\n",
        "title_case",
        "    return ' '.join(w[:1].upper() + w[1:].lower() for w in s.split())\n",
        "    return s.upper()\n",
    ),
    (
        'def dedupe_preserve_order(items):\n    """Return a list with duplicates removed, keeping the FIRST occurrence order.\n    dedupe_preserve_order([3, 1, 3, 2, 1]) -> [3, 1, 2].\n    """\n',
        "def check(candidate):\n    assert candidate([3, 1, 3, 2, 1]) == [3, 1, 2]\n    assert candidate([]) == []\n    assert candidate([1, 1, 1]) == [1]\n    assert candidate(['a', 'b', 'a']) == ['a', 'b']\n",
        "dedupe_preserve_order",
        "    seen = set()\n    out = []\n    for x in items:\n        if x not in seen:\n            seen.add(x)\n            out.append(x)\n    return out\n",
        # Loses first-occurrence order like list(set(...)) would, but sorting
        # makes the result independent of PYTHONHASHSEED, so the bad score is
        # reproducible run to run.
        "    return sorted(set(items))\n",
    ),
    (
        'def roman_to_int(s):\n    """Convert a Roman numeral string (I, V, X, L, C, D, M; valid, uppercase) to an int.\n    roman_to_int(\'IV\') -> 4; roman_to_int(\'XIV\') -> 14; roman_to_int(\'MCMXCIV\') -> 1994.\n    """\n',
        "def check(candidate):\n    assert candidate('IV') == 4\n    assert candidate('XIV') == 14\n    assert candidate('MCMXCIV') == 1994\n    assert candidate('III') == 3\n    assert candidate('LVIII') == 58\n",
        "roman_to_int",
        "    vals = {'I':1,'V':5,'X':10,'L':50,'C':100,'D':500,'M':1000}\n    total = 0\n    prev = 0\n    for c in reversed(s):\n        v = vals[c]\n        if v < prev:\n            total -= v\n        else:\n            total += v\n            prev = v\n    return total\n",
        "    return sum({'I':1,'V':5,'X':10,'L':50,'C':100,'D':500,'M':1000}[c] for c in s)\n",
    ),
]


def _write_jsonl(rows: list[dict], path: Path) -> None:
    """Write *rows* to *path* as UTF-8 JSON Lines, replacing *path* atomically.

    Rows go to a sibling ``<name>.tmp`` file first, which is then moved over
    *path* with ``Path.replace``. An interrupted run therefore never leaves a
    half-written dataset at *path*.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp = path.with_name(path.name + ".tmp")
    # Explicit encoding/newline: do not depend on the platform's locale encoding
    # or on "\r\n" translation on Windows.
    with open(tmp, "w", encoding="utf-8", newline="\n") as f:
        f.writelines(json.dumps(r) + "\n" for r in rows)
    tmp.replace(path)


def main() -> int:
    """Validate every problem in PROBLEMS with spec_rl, then write the dataset.

    For each problem, the reference (good) body must score exactly 1.0 and the
    deliberately wrong (bad) body must score below 1.0 under
    ``spec_rl.fraction_passing``. A status line is printed per problem. If any
    problem fails validation, the failures are reported on stderr and nothing is
    written.

    Returns:
        Process exit code: 0 when all problems validate and OUT was written,
        1 on any validation failure.
    """
    rows = []
    failures = []
    for i, (prompt, test, ep, good, bad) in enumerate(PROBLEMS):
        prob = {"prompt": prompt, "test": test, "entry_point": ep}
        good_score = spec_rl.fraction_passing(prob, good)
        bad_score = spec_rl.fraction_passing(prob, bad)
        ok = good_score == 1.0 and bad_score < 1.0
        status = "OK" if ok else "BAD"
        print(f"[{status}] {ep:24} good={good_score:.3f} bad={bad_score:.3f}")
        if not ok:
            failures.append((ep, good_score, bad_score))
        rows.append({**prob, "task_id": f"adaption_{i}"})
    if failures:
        print("VALIDATION FAILURES:", failures, file=sys.stderr)
        return 1
    _write_jsonl(rows, OUT)
    print(f"\nWROTE {len(rows)} validated rows -> {OUT}")
    return 0


if __name__ == "__main__":
    # Hand main()'s integer result to the interpreter as the process exit status.
    sys.exit(main())