File size: 4,624 Bytes
0dd7c80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""MBPP train-side loader (disjoint from MBPP+ test set).

EvalPlus draws MBPP+ task_ids from across the whole MBPP dataset, not
just HuggingFace's "test" split. So loading `mbpp[train]` alone overlaps
MBPP+ by ~107 tasks. This module loads train + validation + prompt
splits, then explicitly filters out anything whose task_id appears in
MBPP+, returning ~345 contamination-free `Task` records.

Used by:
    - `scripts/build_sft_dataset.py` — Week 3 rejection-sampling source
    - `scripts/build_grpo_dataset.py` — Week 4 GRPO prompt source
"""

from __future__ import annotations

import builtins
import keyword
import re
from collections.abc import Iterable
from typing import Any, cast

from datasets import load_dataset  # type: ignore[import-untyped]

from .base import Task
from .mbpp_plus import load_mbpp_plus


def infer_entry_point(test_list: list[str]) -> str | None:
    """Infer the name of the function under test from MBPP assertions.

    Each assertion is scanned, in order, for call expressions that follow
    the `assert` keyword. The first called name that is not a Python
    keyword, not an attribute access (e.g. `math.isclose`), and not a
    builtin is returned. This sees through common wrappers such as
    `assert set(f(x)) == ...`, `assert not f(x)`, `assert (f(x)) == ...`
    and `assert math.isclose(f(x), ...)`, which a plain
    `assert name(` match either misreads (returning `set`) or misses.

    If the only calls found are to builtin names, the first such name is
    returned, so a plain `assert name(args) == expected` still works for
    a solution that deliberately shadows a builtin.

    Args:
        test_list: Assertion strings for one task.

    Returns:
        The inferred function name, or None if no assertion contains a
        call expression.
    """
    # A call is an identifier directly followed by "("; the lookbehind
    # rejects attribute access and matches starting mid-identifier.
    call_name = re.compile(r"(?<![\w.])([^\W\d]\w*)\(")
    builtin_fallback: str | None = None
    for assertion in test_list:
        head = re.search(r"\bassert\b", assertion)
        if head is None:
            continue
        for call in call_name.finditer(assertion, head.end()):
            name = call.group(1)
            if keyword.iskeyword(name):
                # e.g. `assert not(f(x))` -- `not` is not a callee.
                continue
            if hasattr(builtins, name):
                # Probably a wrapper (`set`, `len`, ...); remember it in
                # case no better candidate shows up anywhere.
                if builtin_fallback is None:
                    builtin_fallback = name
                continue
            return name
    return builtin_fallback


def extract_helpers(canonical_solution: str, entry_point: str) -> str:
    """Return everything before the entry_point's top-level `def` line.

    Captures imports, helper classes (e.g. `Pair`), and helper functions
    that the canonical solution defines ahead of the target function.

    Whitespace is normalized: CRLF and bare-CR line endings become LF and
    tabs are expanded to 4-space stops, so the prompt looks clean and the
    model doesn't pick up tab habits.

    Args:
        canonical_solution: Full reference solution source.
        entry_point: Name of the target function.

    Returns:
        The normalized preamble with trailing whitespace stripped, or ""
        if no column-0 `def <entry_point>(` is found or the solution
        starts with that function (no preamble).
    """
    # Normalize line endings *before* searching: in MULTILINE mode `^`
    # anchors only after "\n", so a bare-CR source would never match and
    # its helpers would be silently dropped.
    normalized = canonical_solution.replace("\r\n", "\n").replace("\r", "\n")
    pattern = re.compile(
        rf"^def\s+{re.escape(entry_point)}\s*\(",
        re.MULTILINE,
    )
    match = pattern.search(normalized)
    if match is None:
        return ""
    return normalized[: match.start()].rstrip().expandtabs(4)


def build_mbpp_prompt(text: str, test_list: list[str], helpers: str = "") -> str:
    """Assemble the user-turn prompt for one MBPP-style task.

    The prompt states the task description, optionally shows supporting
    definitions in a fenced Python block, then quotes the first test
    assertion as an example the solution must satisfy.

    Args:
        text: Natural-language task description; surrounding whitespace
            is stripped.
        test_list: Assertion strings; only the first is shown (stripped).
            An empty list yields an empty example line.
        helpers: Source of supporting definitions. When empty, the
            supporting-definitions section is omitted.

    Returns:
        The prompt as a single newline-joined string.
    """
    first_test = ""
    if test_list:
        first_test = test_list[0].strip()

    # Optional section: only present when there are helpers to show.
    helper_section: list[str] = []
    if helpers:
        helper_section = [
            "Supporting definitions (already available — do not redefine):",
            "```python",
            helpers,
            "```",
            "",
        ]

    lines: list[str] = [
        f"Task: {text.strip()}",
        "",
        *helper_section,
        "Your function must satisfy this example:",
        f"    {first_test}",
        "",
        "Write the complete Python function.",
    ]
    return "\n".join(lines)


def load_mbpp_train() -> list[Task]:
    """Load MBPP tasks whose ids do not appear in MBPP+.

    Walks the "train", "validation" and "prompt" splits of the `mbpp`
    dataset in that order. A record is skipped when its numeric task_id
    is in the MBPP+ id set, or when no entry point can be inferred from
    its test list (which includes the case of an empty test list).

    Returns:
        One `Task` per retained record, in split/iteration order, with
        task_id formatted as `Mbpp/<id>`.
    """
    # Numeric ids of the held-out MBPP+ tasks (the part after the "/").
    held_out: set[int] = set()
    for plus_task in load_mbpp_plus():
        held_out.add(int(plus_task.task_id.split("/")[-1]))

    collected: list[Task] = []
    for split_name in ("train", "validation", "prompt"):
        dataset = cast("Any", load_dataset("mbpp", split=split_name))
        for record in dataset:
            row = cast("dict[str, Any]", record)
            numeric_id = int(row["task_id"])
            if numeric_id in held_out:
                continue

            assertions: list[str] = list(row.get("test_list") or [])
            target = infer_entry_point(assertions)
            if not (target and assertions):
                continue

            reference = str(row["code"])
            preamble = extract_helpers(reference, target)
            prompt_text = build_mbpp_prompt(str(row["text"]), assertions, preamble)
            collected.append(
                Task(
                    task_id=f"Mbpp/{numeric_id}",
                    prompt=prompt_text,
                    canonical_solution=reference,
                    test="\n".join(assertions),
                    entry_point=target,
                    helpers=preamble,
                )
            )
    return collected


def check_no_mbpp_plus_contamination(train_tasks: Iterable[Task]) -> None:
    """Verify that no training task id also appears in MBPP+.

    Compares the numeric suffix (the text after the last "/") of each
    task_id in `train_tasks` against those of `load_mbpp_plus()`.

    Raises:
        RuntimeError: If any id is shared; the message reports how many
            ids overlap and up to five of them in ascending order.
    """

    def numeric_id(task: Task) -> int:
        return int(task.task_id.split("/")[-1])

    held_out = set(map(numeric_id, load_mbpp_plus()))
    shared = held_out.intersection(map(numeric_id, train_tasks))
    if not shared:
        return
    raise RuntimeError(
        f"CONTAMINATION: {len(shared)} task_ids overlap MBPP+ test. "
        f"Examples: {sorted(shared)[:5]}"
    )