dmaheshwar22's picture
deploy: replace template with real demo
0dd7c80 verified
"""MBPP train-side loader (disjoint from MBPP+ test set).
EvalPlus draws MBPP+ task_ids from across the whole MBPP dataset, not
just HuggingFace's "test" split. So loading `mbpp[train]` alone overlaps
MBPP+ by ~107 tasks. This module loads train + validation + prompt
splits, then explicitly filters out anything whose task_id appears in
MBPP+, returning ~345 contamination-free `Task` records.
Used by:
- `scripts/build_sft_dataset.py` — Week 3 rejection-sampling source
- `scripts/build_grpo_dataset.py` — Week 4 GRPO prompt source
"""
from __future__ import annotations
import re
from collections.abc import Iterable
from typing import Any, cast
from datasets import load_dataset # type: ignore[import-untyped]
from .base import Task
from .mbpp_plus import load_mbpp_plus
# Matches a bare call name such as `foo(`; the lookbehind rejects attribute
# calls (`math.isclose(`) and matches that would start mid-identifier.
_CALL_NAME_RE = re.compile(r"(?<![\w.])([A-Za-z_]\w*)\s*\(")


def infer_entry_point(test_list: list[str]) -> str | None:
    """Infer the name of the function under test from assertion strings.

    Each assertion is scanned, after its `assert` keyword, for names that
    are called directly. The first called name that is not a Python builtin
    is returned, so wrappers such as `assert set(name(args)) == ...`,
    `assert math.isclose(name(args), ...)`, `assert not name(args)` and
    `assert (name(args)) == ...` all resolve to `name`.

    If every called name in an assertion is a builtin (e.g. the task's
    function shadows `sum`), the first called name is returned instead.

    Args:
        test_list: Assertion source strings, e.g. `assert name(args) == expected`.

    Returns:
        The inferred function name, or None when no assertion contains a
        usable call.
    """
    # Function-scope imports keep this block self-contained.
    import builtins
    import keyword

    for assertion in test_list:
        anchor = re.search(r"\bassert\b", assertion)
        if anchor is None:
            continue
        # Keywords can precede a parenthesis too (`not (...)`, `in (...)`).
        called = [
            name
            for name in _CALL_NAME_RE.findall(assertion, anchor.end())
            if not keyword.iskeyword(name)
        ]
        if not called:
            continue
        for name in called:
            if not hasattr(builtins, name):
                return name
        # Only builtin names were called: assume the solution shadows one.
        return called[0]
    return None
def extract_helpers(canonical_solution: str, entry_point: str) -> str:
    """Return everything that precedes the entry point's `def` line.

    The preamble captures imports, helper classes (e.g. `Pair`), and helper
    functions that the canonical solution defines alongside the target
    function.

    Line endings are normalized to Unix LF *before* the `def` is searched
    for, so sources using CRLF or bare CR are handled alike: `^` in
    multiline mode only matches after `\n`, so a bare-CR source would
    otherwise never match. Tabs in the returned preamble are expanded to
    4-column stops so the prompt looks clean and the model doesn't pick up
    tab habits.

    Args:
        canonical_solution: Full source text of the reference solution.
        entry_point: Name of the target function.

    Returns:
        The normalized preamble with trailing whitespace removed, or "" if
        no top-level `def` for `entry_point` is found or nothing precedes it.
    """
    normalized = canonical_solution.replace("\r\n", "\n").replace("\r", "\n")
    # Anchored at column 0; `\s*\(` after the escaped name stops `foo`
    # from matching a longer name such as `foo_bar`.
    pattern = re.compile(
        rf"^def\s+{re.escape(entry_point)}\s*\(",
        re.MULTILINE,
    )
    match = pattern.search(normalized)
    if match is None:
        return ""
    return normalized[: match.start()].rstrip().expandtabs(4)
def build_mbpp_prompt(text: str, test_list: list[str], helpers: str = "") -> str:
    """Assemble the user-turn prompt for an MBPP-style task.

    The prompt is a newline-joined sequence of: the stripped task
    description, an optional fenced block of supporting definitions, the
    first assertion as a worked example, and a closing instruction.

    Args:
        text: Natural-language task description.
        test_list: Assertion strings; only the first one is shown. An empty
            list leaves the example line blank.
        helpers: Source the tests rely on (classes, imports). When
            non-empty it is shown in a fenced block so the model knows
            which symbols already exist.

    Returns:
        The complete prompt string.
    """
    first_assertion = ""
    if test_list:
        first_assertion = test_list[0].strip()

    # Stays empty when there is nothing to show, so the splat below adds
    # no lines.
    helper_section: list[str] = []
    if helpers:
        helper_section = [
            "Supporting definitions (already available — do not redefine):",
            "```python",
            helpers,
            "```",
            "",
        ]

    lines = [
        f"Task: {text.strip()}",
        "",
        *helper_section,
        "Your function must satisfy this example:",
        f" {first_assertion}",
        "",
        "Write the complete Python function.",
    ]
    return "\n".join(lines)
def load_mbpp_train() -> list[Task]:
    """Load MBPP tasks that do not appear in MBPP+ (our held-out eval set).

    Reads the "train", "validation" and "prompt" splits of the `mbpp`
    dataset and turns each surviving row into a `Task`. A row is dropped
    when its numeric task id also occurs in MBPP+, when it has no test
    assertions, or when no entry point can be inferred from them.

    Returns:
        The retained tasks, in split order then dataset order.
    """
    # Numeric suffix of each MBPP+ task_id, i.e. the part after the last "/".
    held_out: set[int] = set()
    for plus_task in load_mbpp_plus():
        held_out.add(int(plus_task.task_id.split("/")[-1]))

    kept: list[Task] = []
    for split_name in ("train", "validation", "prompt"):
        rows = cast("Any", load_dataset("mbpp", split=split_name))
        for row in rows:
            record = cast("dict[str, Any]", row)
            numeric_id = int(record["task_id"])
            if numeric_id in held_out:
                continue

            assertions: list[str] = list(record.get("test_list") or [])
            target = infer_entry_point(assertions)
            if not (target and assertions):
                continue

            reference = str(record["code"])
            preamble = extract_helpers(reference, target)
            prompt_text = build_mbpp_prompt(str(record["text"]), assertions, preamble)
            kept.append(
                Task(
                    task_id=f"Mbpp/{numeric_id}",
                    prompt=prompt_text,
                    canonical_solution=reference,
                    test="\n".join(assertions),
                    entry_point=target,
                    helpers=preamble,
                )
            )
    return kept
def check_no_mbpp_plus_contamination(train_tasks: Iterable[Task]) -> None:
    """Belt-and-braces check that no training task id also occurs in MBPP+.

    Ids are compared by the integer after the last "/" in each `task_id`.

    Args:
        train_tasks: Tasks intended for training; iterated exactly once.

    Raises:
        RuntimeError: If at least one id is shared with MBPP+. The message
            reports the overlap size and up to five of the smallest ids.
    """

    def numeric_id(task: Task) -> int:
        return int(task.task_id.split("/")[-1])

    held_out_ids = set(map(numeric_id, load_mbpp_plus()))
    candidate_ids = set(map(numeric_id, train_tasks))
    overlap = candidate_ids.intersection(held_out_ids)
    if not overlap:
        return
    raise RuntimeError(
        f"CONTAMINATION: {len(overlap)} task_ids overlap MBPP+ test. "
        f"Examples: {sorted(overlap)[:5]}"
    )