File size: 7,489 Bytes
62dca4c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 | """
HumanEval benchmark evaluation script.
"""
import logging
import re
import textwrap
from typing import Any, Dict, List, Optional, Tuple

from datasets import load_dataset

from .base import Benchmarker
from .registry import BENCHMARKS
from .utils import create_simple_sgl_function
# Fenced code block tagged as Python (```python, ```py, ```python3, any case)
# or untagged; blocks tagged with another language are ignored.
_CODE_FENCE_PATTERN = re.compile(
    r"```[ \t]*(?:python3?|py)?[ \t]*\r?\n(.*?)```", re.DOTALL | re.IGNORECASE
)
# A function definition starting in column 0 (annotations and nested parens
# in the signature are fine because only the opening parenthesis is matched).
_TOP_LEVEL_DEF_PATTERN = re.compile(r"(?:async\s+)?def\s+\w+\s*\(")
# Column-0 lines that may legitimately follow a top-level function.
_TOP_LEVEL_CODE_PATTERN = re.compile(
    r"(?:(?:async\s+)?def\s+\w+\s*\(|class\s+\w+|import\s+[\w.]+"
    r"|from\s+[\w.]+\s+import\s|@\w|#)"
)
# Column-0 lines that may legitimately precede the first function.
_PREAMBLE_PATTERN = re.compile(r"(?:import\s+[\w.]+|from\s+[\w.]+\s+import\s|@\w|#)")


def _trim_blank_edges(text: str) -> str:
    """Drop leading blank lines and trailing whitespace, keeping indentation.

    Unlike ``str.strip`` this preserves the indentation of the first code
    line, which a body-only completion needs to remain inside its function.
    """
    lines = text.rstrip().splitlines()
    while lines and not lines[0].strip():
        lines.pop(0)
    return "\n".join(lines)


def _extract_top_level_definitions(output: str) -> Optional[str]:
    """Return the run of top-level code around the first column-0 ``def``.

    Includes imports/decorators/comments directly above it and any following
    indented lines, blank lines, and further top-level defs/classes/imports.
    Extraction stops at the first other column-0 line (typically prose).
    Returns None when no column-0 ``def`` exists.
    """
    lines = output.splitlines()
    start = next(
        (i for i, line in enumerate(lines) if _TOP_LEVEL_DEF_PATTERN.match(line)),
        None,
    )
    if start is None:
        return None

    first = start
    while first > 0 and (
        not lines[first - 1].strip() or _PREAMBLE_PATTERN.match(lines[first - 1])
    ):
        first -= 1

    end = start + 1
    while end < len(lines):
        line = lines[end]
        if line.strip() and not line[0].isspace():
            if not _TOP_LEVEL_CODE_PATTERN.match(line):
                break
        end += 1

    return _trim_blank_edges("\n".join(lines[first:end]))


def extract_code_from_output(output: str) -> Optional[str]:
    """Extract Python code from model output.

    Tried in order:
    1. The first fenced code block tagged ``python``/``py``/``python3`` or
       left untagged.
    2. The top-level code around the first column-0 function definition
       (with adjacent imports and any following helper functions). Any
       trailing prose is dropped.
    3. The whole output, which may already be bare code.

    Leading blank lines and trailing whitespace are removed, but the first
    code line keeps its indentation so body-only completions stay valid.

    Returns:
        The extracted code, or None if the output is blank.
    """
    fenced = _CODE_FENCE_PATTERN.search(output)
    if fenced:
        return _trim_blank_edges(fenced.group(1))

    definitions = _extract_top_level_definitions(output)
    if definitions:
        return definitions

    # Fallback: the output may already be code (e.g. a raw function body).
    return _trim_blank_edges(output) or None
def check_code_passes_tests(code: str, test_code: str, entry_point: str) -> bool:
    """Check if generated code passes the test cases.

    This is a simplified version. For full evaluation, use the official
    HumanEval evaluation framework.

    HumanEval ``test`` fields only *define* ``check(candidate)``; they never
    call it. The official harness therefore appends ``check(<entry_point>)``
    to the program. The same is done here: after executing ``code`` and
    ``test_code`` in one shared namespace, ``check`` is called with the
    function named ``entry_point``. Test code that defines no ``check``
    function is treated as plain top-level assertions.

    NOTE(security): ``code`` is model output run in-process via ``exec``.
    There is no sandbox, no resource limit and no timeout, so an infinite
    loop in generated code hangs the caller. Run this only in a disposable
    environment.

    Args:
        code: The complete program (prompt plus completion) to evaluate.
        test_code: HumanEval test source, normally defining ``check``.
        entry_point: Name of the function under test, passed to ``check``.

    Returns:
        True if every test passes. False on any assertion failure, runtime
        error, syntax error or ``SystemExit``, or when ``check`` exists but
        ``entry_point`` does not name a callable.
    """
    namespace: Dict[str, Any] = {}
    try:
        exec(code, namespace)
        exec(test_code, namespace)
        check = namespace.get("check")
        if callable(check):
            candidate = namespace.get(entry_point) if entry_point else None
            if not callable(candidate):
                # The function under test was never defined.
                return False
            check(candidate)
    except (Exception, SystemExit):
        # SystemExit included so generated code calling exit() cannot
        # terminate the benchmark run.
        return False
    return True
@BENCHMARKS.register("humaneval")
class HumanEvalBenchmarker(Benchmarker):
    """HumanEval benchmark implementation.

    Each question is the raw HumanEval prompt: imports, any helper functions,
    and the signature plus docstring of the function to implement. Each label
    carries what is needed to grade a completion: the prompt, the test code,
    the entry-point name and the canonical solution.
    """

    def __init__(self, num_samples: Optional[int] = None):
        """Initialize benchmark and the per-problem test storage."""
        super().__init__(num_samples, None)
        # Parallel lists filled by load_data, one entry per loaded problem.
        self.test_cases: List[str] = []
        self.entry_points: List[str] = []

    def load_data(self) -> Tuple[List[Dict[str, Any]], List[Optional[Dict[str, str]]]]:
        """Load and preprocess the HumanEval dataset.

        Returns:
            ``(questions, labels)``. Each question is ``{"question": prompt}``.
            Each label is a dict with ``prompt``, ``test``, ``entry_point`` and
            ``canonical_solution``. At most ``self.num_samples`` problems are
            loaded when it is set.
        """
        dataset = load_dataset("openai/openai_humaneval")["test"]
        questions = []
        labels = []
        self.test_cases = []
        self.entry_points = []
        for idx, q in enumerate(dataset):
            if self.num_samples is not None and idx >= self.num_samples:
                break
            prompt = q["prompt"]
            questions.append({"question": prompt})
            # Store test case and entry point for evaluation
            test_code = q.get("test", "")
            entry_point = q.get("entry_point", "")
            self.test_cases.append(test_code)
            self.entry_points.append(entry_point)
            labels.append(
                {
                    # Kept in the label so grading does not depend on
                    # instance state being populated elsewhere.
                    "prompt": prompt,
                    "test": test_code,
                    "entry_point": entry_point,
                    # Reference solution (optional, for comparison)
                    "canonical_solution": q.get("canonical_solution", ""),
                }
            )
        return questions, labels

    def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[str]:
        """Extract code from model output (``label`` is unused)."""
        return extract_code_from_output(output)

    @staticmethod
    def _assemble_program(prompt: str, completion: str, entry_point: str) -> str:
        """Build the program to execute from the prompt and a completion.

        The prompt is always emitted first, so its imports (e.g.
        ``from typing import List``, needed to evaluate annotations) and any
        helper functions it defines are available:

        * If the completion defines ``entry_point`` at top level (after
          removing common indentation), it is a full definition. It replaces
          the docstring-only stub from the prompt.
        * Otherwise the completion is treated as the function body. It keeps
          its original indentation so it continues the prompt's function.
        """
        lines = completion.rstrip().splitlines()
        # Drop leading blank lines only. The first code line's indentation is
        # what places a body completion inside the function.
        while lines and not lines[0].strip():
            lines.pop(0)
        code = "\n".join(lines)
        if entry_point:
            dedented = textwrap.dedent(code)
            defines_entry_point = re.search(
                rf"^(?:async\s+)?def\s+{re.escape(entry_point)}\s*\(",
                dedented,
                re.MULTILINE,
            )
            if defines_entry_point:
                # A full definition, possibly uniformly indented, belongs
                # at module level.
                code = dedented
        return prompt + "\n" + code + "\n"

    def compute_accuracy(
        self, predictions: List[Any], labels: List[Any]
    ) -> Optional[float]:
        """Compute accuracy for HumanEval by checking if code passes tests.

        Note: This is a simplified evaluation. For official pass@k metrics,
        use the HumanEval evaluation framework.

        Returns:
            The fraction of dict labels whose prediction passes its tests. A
            missing prediction counts as a failure. Returns None when there
            are no labels or all labels are None.
        """
        if not labels or all(label is None for label in labels):
            return None

        correct = 0
        valid_count = 0
        for i, (pred, label) in enumerate(zip(predictions, labels)):
            if not isinstance(label, dict):
                continue
            valid_count += 1
            if pred is None:
                continue
            try:
                prompt = label.get("prompt")
                if prompt is None:
                    # Labels without a stored prompt fall back to the question
                    # list. Presumably self.questions is set by the base class
                    # — TODO confirm.
                    prompt = self.questions[i]["question"]
                entry_point = label.get("entry_point", "")
                test_code = label.get("test", "")
                full_code = self._assemble_program(prompt, str(pred), entry_point)
                if test_code and check_code_passes_tests(
                    full_code, test_code, entry_point
                ):
                    correct += 1
            except Exception:
                # Grading problems count as incorrect but are logged rather
                # than silently discarded.
                logging.getLogger(__name__).debug(
                    "Error evaluating HumanEval sample %d", i, exc_info=True
                )
        return correct / valid_count if valid_count > 0 else 0.0

    def create_sgl_function(self):
        """Create SGL function for HumanEval."""
        return create_simple_sgl_function(
            function_name="get_humaneval_answer",
            answer_key="answer",
            max_tokens=self.get_max_new_tokens(),
        )

    def get_max_new_tokens(self) -> int:
        """HumanEval code generation requires more tokens."""
        return 1024
|