"""MBPP train-side loader (disjoint from MBPP+ test set). EvalPlus draws MBPP+ task_ids from across the whole MBPP dataset, not just HuggingFace's "test" split. So loading `mbpp[train]` alone overlaps MBPP+ by ~107 tasks. This module loads train + validation + prompt splits, then explicitly filters out anything whose task_id appears in MBPP+, returning ~345 contamination-free `Task` records. Used by: - `scripts/build_sft_dataset.py` — Week 3 rejection-sampling source - `scripts/build_grpo_dataset.py` — Week 4 GRPO prompt source """ from __future__ import annotations import re from collections.abc import Iterable from typing import Any, cast from datasets import load_dataset # type: ignore[import-untyped] from .base import Task from .mbpp_plus import load_mbpp_plus def infer_entry_point(test_list: list[str]) -> str | None: """Extract the function name from `assert name(args) == expected`.""" for assertion in test_list: if m := re.search(r"assert\s+(\w+)\(", assertion): return m.group(1) return None def extract_helpers(canonical_solution: str, entry_point: str) -> str: """Return everything before the entry_point's `def` line. Captures imports, helper classes (e.g. `Pair`), and helper functions that the canonical solution defines alongside the target function. Returns "" if the entry_point's def isn't found, or the canonical solution starts with the function (no preamble). Whitespace is normalized — MBPP's source contains some Windows-style CRLF + tab indentation. We convert to Unix LF + 4-space indents so the prompt looks clean and the model doesn't pick up tab habits. """ pattern = re.compile( rf"^def\s+{re.escape(entry_point)}\s*\(", re.MULTILINE, ) match = pattern.search(canonical_solution) if not match: return "" raw = canonical_solution[: match.start()].rstrip() return raw.replace("\r\n", "\n").replace("\r", "\n").expandtabs(4) def build_mbpp_prompt(text: str, test_list: list[str], helpers: str = "") -> str: """Compose the user-turn prompt for an MBPP-style task. If `helpers` is non-empty (i.e. the task's canonical solution defines a class/import the tests rely on), include it in the prompt so the model knows what types/symbols it can use. """ example = test_list[0].strip() if test_list else "" parts: list[str] = [f"Task: {text.strip()}", ""] if helpers: parts.extend([ "Supporting definitions (already available — do not redefine):", "```python", helpers, "```", "", ]) parts.extend([ "Your function must satisfy this example:", f" {example}", "", "Write the complete Python function.", ]) return "\n".join(parts) def load_mbpp_train() -> list[Task]: """Return MBPP tasks disjoint from MBPP+ (our held-out eval set).""" plus_ids: set[int] = { int(t.task_id.split("/")[-1]) for t in load_mbpp_plus() } tasks: list[Task] = [] for split in ("train", "validation", "prompt"): ds = cast("Any", load_dataset("mbpp", split=split)) for item in ds: item_d = cast("dict[str, Any]", item) task_id_int = int(item_d["task_id"]) if task_id_int in plus_ids: continue test_list: list[str] = list(item_d.get("test_list") or []) entry_point = infer_entry_point(test_list) if not entry_point or not test_list: continue canonical = str(item_d["code"]) helpers = extract_helpers(canonical, entry_point) tasks.append( Task( task_id=f"Mbpp/{task_id_int}", prompt=build_mbpp_prompt(str(item_d["text"]), test_list, helpers), canonical_solution=canonical, test="\n".join(test_list), entry_point=entry_point, helpers=helpers, ) ) return tasks def check_no_mbpp_plus_contamination(train_tasks: Iterable[Task]) -> None: """Belt-and-braces: confirm zero overlap between train + MBPP+ test.""" plus_ids = {int(t.task_id.split("/")[-1]) for t in load_mbpp_plus()} train_ids = {int(t.task_id.split("/")[-1]) for t in train_tasks} overlap = train_ids & plus_ids if overlap: raise RuntimeError( f"CONTAMINATION: {len(overlap)} task_ids overlap MBPP+ test. " f"Examples: {sorted(overlap)[:5]}" )