Spaces:
Sleeping
Sleeping
| """MBPP train-side loader (disjoint from MBPP+ test set). | |
| EvalPlus draws MBPP+ task_ids from across the whole MBPP dataset, not | |
| just HuggingFace's "test" split. So loading `mbpp[train]` alone overlaps | |
| MBPP+ by ~107 tasks. This module loads train + validation + prompt | |
| splits, then explicitly filters out anything whose task_id appears in | |
| MBPP+, returning ~345 contamination-free `Task` records. | |
| Used by: | |
| - `scripts/build_sft_dataset.py` — Week 3 rejection-sampling source | |
| - `scripts/build_grpo_dataset.py` — Week 4 GRPO prompt source | |
| """ | |
| from __future__ import annotations | |
| import re | |
| from collections.abc import Iterable | |
| from typing import Any, cast | |
| from datasets import load_dataset # type: ignore[import-untyped] | |
| from .base import Task | |
| from .mbpp_plus import load_mbpp_plus | |
def infer_entry_point(test_list: list[str]) -> str | None:
    """Guess the name of the function under test from MBPP assertions.

    Scans the assertions in order and, for the first one containing
    `assert <name>(`, returns `<name>`. Returns None when no assertion
    has that shape (including when `test_list` is empty).
    """
    call_after_assert = re.compile(r"assert\s+(\w+)\(")
    candidates = (call_after_assert.search(line) for line in test_list)
    first_hit = next((hit for hit in candidates if hit is not None), None)
    return first_hit.group(1) if first_hit is not None else None
def extract_helpers(canonical_solution: str, entry_point: str) -> str:
    """Return the preamble that precedes the entry point's `def` line.

    The preamble is whatever the canonical solution places before the
    target function: imports, helper classes, helper functions. The
    result is "" when no column-0 `def <entry_point>(` line exists, or
    when nothing but whitespace comes before it.

    Line endings are converted to LF and tabs are expanded to 4-column
    stops so the text can be dropped into a prompt without carrying
    over CRLF or tab indentation from the dataset.
    """
    target_def = re.search(
        rf"^def\s+{re.escape(entry_point)}\s*\(",
        canonical_solution,
        flags=re.MULTILINE,
    )
    if target_def is None:
        return ""

    # Keep only the text ahead of the target function, minus trailing blanks.
    preamble = canonical_solution[: target_def.start()].rstrip()
    unix_lines = preamble.replace("\r\n", "\n").replace("\r", "\n")
    return unix_lines.expandtabs(4)
def build_mbpp_prompt(text: str, test_list: list[str], helpers: str = "") -> str:
    """Compose the user-turn prompt for an MBPP-style task.

    Args:
        text: Natural-language task description; surrounding whitespace
            is stripped.
        test_list: Assertion strings for the task. Only the first one is
            shown to the model, as a worked example.
        helpers: Optional source preamble (imports, helper classes or
            functions) that the tests rely on. When non-empty it is
            embedded in a fenced code block so the model knows which
            symbols already exist.

    Returns:
        The prompt as a single newline-joined string. The example section
        is left out entirely when there is no usable example (empty
        `test_list` or a blank first entry), rather than announcing an
        example and then showing an empty line.
    """
    parts: list[str] = [f"Task: {text.strip()}", ""]

    if helpers:
        parts.extend([
            "Supporting definitions (already available — do not redefine):",
            "```python",
            helpers,
            "```",
            "",
        ])

    example = test_list[0].strip() if test_list else ""
    # Only promise an example when there is one to show.
    if example:
        parts.extend([
            "Your function must satisfy this example:",
            f" {example}",
            "",
        ])

    parts.append("Write the complete Python function.")
    return "\n".join(parts)
def load_mbpp_train() -> list[Task]:
    """Load MBPP tasks whose ids do not appear in MBPP+.

    Reads the "train", "validation" and "prompt" splits of the `mbpp`
    dataset in that order. A row is dropped when its task id is present
    in the MBPP+ set, or when no entry point can be inferred from its
    assertions (which includes rows with no assertions at all). Each
    surviving row becomes a `Task` with id `Mbpp/<n>`.
    """
    held_out: set[int] = {
        int(plus_task.task_id.split("/")[-1]) for plus_task in load_mbpp_plus()
    }

    kept: list[Task] = []
    for split_name in ("train", "validation", "prompt"):
        rows = cast("Any", load_dataset("mbpp", split=split_name))
        for raw_row in rows:
            row = cast("dict[str, Any]", raw_row)

            numeric_id = int(row["task_id"])
            if numeric_id in held_out:
                # Part of the held-out evaluation set; never train on it.
                continue

            assertions: list[str] = list(row.get("test_list") or [])
            target = infer_entry_point(assertions)
            if not (target and assertions):
                continue

            reference = str(row["code"])
            preamble = extract_helpers(reference, target)
            kept.append(
                Task(
                    task_id=f"Mbpp/{numeric_id}",
                    prompt=build_mbpp_prompt(str(row["text"]), assertions, preamble),
                    canonical_solution=reference,
                    test="\n".join(assertions),
                    entry_point=target,
                    helpers=preamble,
                )
            )
    return kept
def check_no_mbpp_plus_contamination(train_tasks: Iterable[Task]) -> None:
    """Raise if any training task shares a numeric id with an MBPP+ task.

    Ids are compared by the integer after the last "/" in `task_id`.

    Raises:
        RuntimeError: when at least one id is shared; the message gives
            the overlap count and up to five of the smallest shared ids.
    """

    def _numeric_ids(tasks: Iterable[Task]) -> set[int]:
        # "Mbpp/123" -> 123
        return {int(task.task_id.split("/")[-1]) for task in tasks}

    eval_ids = _numeric_ids(load_mbpp_plus())
    shared = _numeric_ids(train_tasks) & eval_ids
    if not shared:
        return

    raise RuntimeError(
        f"CONTAMINATION: {len(shared)} task_ids overlap MBPP+ test. "
        f"Examples: {sorted(shared)[:5]}"
    )