File size: 2,225 Bytes
0dd7c80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
"""MBPP+ loader (EvalPlus-augmented MBPP, ~378 sanitized tasks).

MBPP+ applies the same test-augmentation treatment as HumanEval+: the original
MBPP has very thin tests (usually 3 assertions per task); MBPP+ expands each
problem to ~100 tests and filters out problems whose ground truth is
ambiguous or buggy.

Same API as `load_humaneval_plus` — returns a deterministically ordered list
of `Task`.
"""

from __future__ import annotations

from typing import Any, cast

from evalplus.data import get_mbpp_plus  # type: ignore[import-untyped]

from .base import Task


def _extract_mbpp_helpers(canonical_solution: str, entry_point: str) -> str:
    """Same helper-extraction logic as mbpp_train (kept independent to
    avoid an import cycle between mbpp_plus and mbpp_train)."""
    import re
    pattern = re.compile(
        rf"^def\s+{re.escape(entry_point)}\s*\(",
        re.MULTILINE,
    )
    match = pattern.search(canonical_solution)
    if not match:
        return ""
    raw = canonical_solution[: match.start()].rstrip()
    return raw.replace("\r\n", "\n").replace("\r", "\n").expandtabs(4)


def load_mbpp_plus() -> list[Task]:
    """Load every MBPP+ problem as a `Task`, ordered by numeric task id.

    MBPP+ records expose their tests under the `assertion` key. That key holds
    plain `assert ...` statements that call the entry point by name, unlike
    the `test` key of HumanEval+, which holds a `def check(candidate)` wrapper.
    The assertion text is stored unchanged in `Task.test`.

    Whatever precedes the entry-point definition in `canonical_solution`
    (imports, helper classes, ...) is copied into `Task.helpers`, so the
    sandbox executor can prepend it at run time.

    Returns:
        A new list of `Task`, sorted ascending by the integer that follows
        the last "/" in each task_id.

    Raises:
        KeyError: if a record lacks one of the fields read below.
        ValueError: if a task_id does not end in an integer after its last "/".
    """
    problems = cast("dict[str, dict[str, Any]]", get_mbpp_plus())

    def _numeric_suffix(task: Task) -> int:
        # Order by the trailing integer rather than lexically, so ".../10"
        # sorts after ".../9" (task_ids presumably look like "Mbpp/<n>").
        return int(task.task_id.split("/")[-1])

    loaded = [
        Task(
            task_id=tid,
            prompt=problem["prompt"],
            canonical_solution=problem["canonical_solution"],
            test=problem["assertion"],
            entry_point=problem["entry_point"],
            helpers=_extract_mbpp_helpers(
                problem["canonical_solution"], problem["entry_point"]
            ),
        )
        for tid, problem in problems.items()
    ]
    return sorted(loaded, key=_numeric_suffix)