| """MBPP+ loader (EvalPlus-augmented MBPP, ~378 sanitized tasks). |
| |
| MBPP+ applies the same test-augmentation treatment as HumanEval+: the original |
| MBPP has very thin tests (usually 3 assertions per task); MBPP+ expands each |
| problem to ~100 tests and filters out problems whose ground truth is |
| ambiguous or buggy. |
| |
| Same API as `load_humaneval_plus` — returns a deterministically ordered list |
| of `Task`. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from typing import Any, cast |
|
|
| from evalplus.data import get_mbpp_plus |
|
|
| from .base import Task |
|
|
|
|
| def _extract_mbpp_helpers(canonical_solution: str, entry_point: str) -> str: |
| """Same helper-extraction logic as mbpp_train (kept independent to |
| avoid an import cycle between mbpp_plus and mbpp_train).""" |
| import re |
| pattern = re.compile( |
| rf"^def\s+{re.escape(entry_point)}\s*\(", |
| re.MULTILINE, |
| ) |
| match = pattern.search(canonical_solution) |
| if not match: |
| return "" |
| raw = canonical_solution[: match.start()].rstrip() |
| return raw.replace("\r\n", "\n").replace("\r", "\n").expandtabs(4) |
|
|
|
|
def load_mbpp_plus() -> list[Task]:
    """Load every MBPP+ problem as a ``Task``, ordered by numeric task id.

    HumanEval+ keeps its tests under ``test`` as a ``def check(candidate)``
    wrapper. MBPP+ instead keeps them under ``assertion`` as plain
    ``assert ...`` lines that call the entry point by name, and that text is
    copied unchanged into ``Task.test``.

    Any preamble in ``canonical_solution`` ahead of the entry point (imports,
    classes, helper functions) is split out into ``Task.helpers`` so the
    sandbox executor can prepend it at run time.

    Returns:
        One ``Task`` per MBPP+ problem, sorted by the integer that follows
        the last ``/`` in its ``task_id``.
    """
    dataset = cast("dict[str, dict[str, Any]]", get_mbpp_plus())

    def build(task_id: str, record: dict[str, Any]) -> Task:
        # Convert one raw MBPP+ record into a Task.
        solution = record["canonical_solution"]
        fn_name = record["entry_point"]
        return Task(
            task_id=task_id,
            prompt=record["prompt"],
            canonical_solution=solution,
            test=record["assertion"],
            entry_point=fn_name,
            helpers=_extract_mbpp_helpers(solution, fn_name),
        )

    def numeric_id(task: Task) -> int:
        # Order by the integer id, not the string, so ".../10" follows ".../9".
        return int(task.task_id.split("/")[-1])

    return sorted(
        (build(task_id, record) for task_id, record in dataset.items()),
        key=numeric_id,
    )
|
|