File size: 7,489 Bytes
62dca4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
"""
HumanEval benchmark evaluation script.
"""

import re
from typing import Any, Dict, List, Optional, Tuple

from datasets import load_dataset

from .base import Benchmarker
from .registry import BENCHMARKS
from .utils import create_simple_sgl_function


def _trim_code(text: str) -> str:
    """Drop leading blank lines and trailing whitespace from ``text``.

    Unlike ``str.strip()``, the indentation of the first code line is kept,
    which matters when ``text`` is an indented function body that will be
    appended to a function signature.
    """
    return re.sub(r"^\s*\n", "", text).rstrip()


def extract_code_from_output(output: str) -> Optional[str]:
    """Extract Python code from model output.

    Extraction is attempted in this order:

    1. The first fenced markdown code block. The language tag is optional
       and may be ``python``, ``python3`` or ``py`` in any letter case.
    2. The first ``def`` statement, with or without a ``-> ReturnType``
       annotation, up to the next blank-line-separated ``def`` or the end of
       the output.
    3. The whole output.

    In cases 1 and 3 only leading blank lines and trailing whitespace are
    removed. The indentation of the first code line is preserved, so a bare
    function body (the classic HumanEval completion format) stays valid when
    appended to the prompt.

    Args:
        output: Raw text generated by the model.

    Returns:
        The extracted code (possibly ``""`` for an empty fenced block), or
        ``None`` if ``output`` is empty or whitespace only.
    """
    # 1. Fenced markdown block; tolerate "py"/"python3"/"Python" tags and
    #    stray spaces/tabs around the tag.
    code_block_pattern = r"```[ \t]*(?:python3?|py)?[ \t]*\n(.*?)```"
    match = re.search(code_block_pattern, output, re.DOTALL | re.IGNORECASE)
    if match:
        return _trim_code(match.group(1))

    # 2. A bare function definition. The optional "-> ..." group accepts
    #    return annotations such as "-> List[int]", which HumanEval
    #    signatures almost always carry. The match extends to the next
    #    "\n\ndef" or to the end of the output.
    def_pattern = r"(def\s+\w+\s*\([^)]*\)\s*(?:->[^:\n]*)?:.*?)(?=\n\ndef\s+|\Z)"
    match = re.search(def_pattern, output, re.DOTALL)
    if match:
        return match.group(1).strip()

    # 3. Fallback: the output may already be code (e.g. an indented body).
    if not output.strip():
        return None
    return _trim_code(output)


def check_code_passes_tests(code: str, test_code: str, entry_point: str) -> bool:
    """Return True if ``code`` passes the HumanEval-style ``test_code``.

    This is a simplified version. For full evaluation, use the official
    HumanEval evaluation framework.

    ``code`` is executed first, then ``test_code``, both in one fresh
    namespace. HumanEval's ``test`` field only *defines*
    ``check(candidate)`` and runs no assertions by itself. So, as in the
    official harness, when ``test_code`` defines a top-level ``check``
    function it is called with the ``entry_point`` function. Test code
    without a ``check`` function (plain top-level assertions) is still
    supported: finishing without an exception counts as a pass.

    WARNING: ``code`` is untrusted model output. It is run with ``exec`` in
    the current process, with no sandbox and no timeout. A generated
    infinite loop hangs the caller, and generated code can touch the
    filesystem or interpreter state.

    Args:
        code: Program text (prompt plus completion) defining ``entry_point``.
        test_code: Test snippet, usually defining ``check(candidate)``.
        entry_point: Name of the function under test.

    Returns:
        True if all tests pass. False on any failed assertion, on any other
        exception (including ``SystemExit`` raised by the generated code),
        or when ``check`` exists but ``entry_point`` is missing, empty, or
        not callable.
    """
    # Only invoke a `check` that the test snippet itself defines, never one
    # that happens to come from the generated code.
    defines_check = (
        re.search(r"^def\s+check\s*\(", test_code, re.MULTILINE) is not None
    )
    namespace: Dict[str, Any] = {}
    try:
        exec(code, namespace)
        exec(test_code, namespace)
        if defines_check:
            candidate = namespace.get(entry_point) if entry_point else None
            if not callable(candidate):
                return False
            namespace["check"](candidate)
    except (Exception, SystemExit):
        # Failed assertion, syntax/runtime error, or the generated code tried
        # to exit the interpreter: all count as a failed test.
        return False
    return True


@BENCHMARKS.register("humaneval")
class HumanEvalBenchmarker(Benchmarker):
    """HumanEval benchmark implementation.

    Each problem supplies a ``prompt`` (imports, optional helper functions and
    the target function's signature plus docstring), a ``test`` snippet and
    the ``entry_point`` name of the function under test. Accuracy is the
    fraction of problems whose generated code passes its tests (one sample
    per problem).
    """

    def __init__(self, num_samples: Optional[int] = None):
        """Initialize the benchmark and the per-problem test storage.

        Args:
            num_samples: If set, ``load_data`` keeps only the first
                ``num_samples`` problems; ``None`` keeps all of them.
        """
        super().__init__(num_samples, None)
        # Filled by load_data() in dataset order. The same values are also
        # carried in each label.
        self.test_cases: List[str] = []
        self.entry_points: List[str] = []

    def load_data(self) -> Tuple[List[Dict[str, Any]], List[Optional[Dict[str, str]]]]:
        """Load and preprocess the HumanEval test split.

        Returns:
            ``(questions, labels)``, where ``questions[i]`` is
            ``{"question": prompt}`` and ``labels[i]`` holds the problem's
            ``test`` code, ``entry_point``, ``canonical_solution`` and
            ``prompt``.
        """
        dataset = load_dataset("openai/openai_humaneval")["test"]
        questions: List[Dict[str, Any]] = []
        labels: List[Optional[Dict[str, str]]] = []
        self.test_cases = []
        self.entry_points = []

        for idx, problem in enumerate(dataset):
            if self.num_samples is not None and idx >= self.num_samples:
                break

            prompt = problem["prompt"]
            test_code = problem.get("test", "")
            entry_point = problem.get("entry_point", "")

            questions.append({"question": prompt})
            self.test_cases.append(test_code)
            self.entry_points.append(entry_point)
            labels.append(
                {
                    "test": test_code,
                    "entry_point": entry_point,
                    # Reference solution, kept for comparison only.
                    "canonical_solution": problem.get("canonical_solution", ""),
                    # Stored so scoring can rebuild the full program without
                    # relying on self.questions.
                    "prompt": prompt,
                }
            )

        return questions, labels

    def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[str]:
        """Extract code from model output (``label`` is unused)."""
        return extract_code_from_output(output)

    @staticmethod
    def _normalize_completion(text: str) -> str:
        """Drop leading blank lines and trailing whitespace only.

        The first line's indentation is kept so that a bare function body
        remains correctly nested under the prompt's signature.
        """
        return re.sub(r"^\s*\n", "", text).rstrip()

    def _get_prompt(self, index: int, label: Dict[str, Any]) -> Optional[str]:
        """Return the prompt for problem ``index``, or None if unavailable.

        Prefers the ``prompt`` stored in the label by ``load_data``. For
        labels without one, it falls back to
        ``self.questions[index]["question"]``.
        NOTE(review): ``self.questions`` is assumed to be set by the base
        class from ``load_data``'s return value; confirm.
        """
        prompt = label.get("prompt")
        if isinstance(prompt, str):
            return prompt
        try:
            return self.questions[index]["question"]
        except (AttributeError, IndexError, KeyError, TypeError):
            return None

    def compute_accuracy(
        self, predictions: List[Any], labels: List[Any]
    ) -> Optional[float]:
        """Compute accuracy for HumanEval by checking if code passes tests.

        Note: This is a simplified evaluation. For official pass@k metrics,
        use the HumanEval evaluation framework.

        The program under test is always ``prompt + "\\n" + completion``:

        * a bare function body continues the prompt's function definition;
        * a complete function (or imports plus a function) simply redefines
          the prompt's docstring-only stub, while still getting the prompt's
          imports (e.g. ``from typing import List``, which annotations need
          at definition time) and any helper functions the prompt defines.

        Returns:
            ``None`` if ``labels`` is empty or all labels are ``None``.
            Otherwise the number of passing predictions divided by the number
            of dict labels, or ``0.0`` if no label is a dict.
        """
        if not labels or all(label is None for label in labels):
            return None

        correct = 0
        valid_count = 0

        # NOTE(review): zip() stops at the shorter list; predictions, labels
        # and self.questions are assumed to be aligned one-to-one.
        for i, (pred, label) in enumerate(zip(predictions, labels)):
            if not isinstance(label, dict):
                continue
            valid_count += 1
            if pred is None:
                continue

            completion = self._normalize_completion(str(pred))
            if not completion.strip():
                continue  # Empty answer: counted as incorrect.

            prompt = self._get_prompt(i, label)
            if prompt is None:
                continue  # Cannot rebuild the program: counted as incorrect.

            test_code = label.get("test", "")
            if not test_code:
                continue

            full_code = prompt + "\n" + completion
            if check_code_passes_tests(
                full_code, test_code, label.get("entry_point", "")
            ):
                correct += 1

        return correct / valid_count if valid_count > 0 else 0.0

    def create_sgl_function(self):
        """Create the SGL function used to query the model for HumanEval."""
        return create_simple_sgl_function(
            function_name="get_humaneval_answer",
            answer_key="answer",
            max_tokens=self.get_max_new_tokens(),
        )

    def get_max_new_tokens(self) -> int:
        """Return the generation budget (1024 tokens; code needs more room)."""
        return 1024