File size: 4,329 Bytes
62dca4c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 | """
MATH-500 benchmark evaluation script.
"""
import re
from typing import Any, Dict, List, Optional, Tuple
from datasets import load_dataset
from .base import Benchmarker
from .registry import BENCHMARKS
from .utils import create_simple_sgl_function
def _last_boxed_content(output: str) -> Optional[str]:
    r"""Return the contents of the last non-empty ``\boxed{...}`` / ``\fbox{...}``.

    Braces are matched by depth counting, so nested groups such as
    ``\boxed{\frac{1}{2}}`` are returned whole instead of being cut at the
    first ``}``. Boxes whose braces never close (e.g. a truncated
    generation) or whose contents are blank are skipped.

    Returns ``None`` when no usable box is present.
    """
    openers = list(re.finditer(r"\\(?:boxed|fbox)\s*\{", output))
    # Scan from the end: the final boxed expression is conventionally the answer.
    for opener in reversed(openers):
        open_idx = opener.end() - 1  # index of the opening "{"
        depth = 0
        for idx in range(open_idx, len(output)):
            char = output[idx]
            if char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
                if depth == 0:
                    content = output[open_idx + 1 : idx].strip()
                    if content:
                        return content
                    break  # empty box: try an earlier one
    return None


def extract_math_answer(output: str) -> Optional[str]:
    r"""Extract the final answer from a math problem solution.

    Strategies, in priority order:

    1. The contents of the last ``\boxed{...}`` (or ``\fbox{...}``), with
       nested braces preserved, e.g. ``\boxed{\frac{1}{2}}`` -> ``\frac{1}{2}``.
    2. The token after the last brace-less ``\boxed`` (``\boxed 5`` -> ``5``).
    3. The number after the last "answer" / "Answer:" phrase (case-insensitive).
    4. The last integer or decimal anywhere in the text.

    Args:
        output: Model generation or reference solution text.

    Returns:
        The extracted answer string, or ``None`` if nothing could be found
        (including when ``output`` is empty or ``None``).
    """
    if not output:
        return None

    boxed = _last_boxed_content(output)
    if boxed is not None:
        return boxed

    bare_boxed = re.findall(r"\\boxed\s+([^\s]+)", output)
    if bare_boxed:
        return bare_boxed[-1].strip()

    # e.g. "The answer: 42", "ANSWER 3.14"; case is handled by IGNORECASE.
    labelled = re.findall(r"answer[\s:]+([-+]?\d*\.?\d+)", output, re.IGNORECASE)
    if labelled:
        return labelled[-1].strip()

    # Fallback: the last number in the text. This also covers endings like
    # "... is 42" / "... = 42", whose number is by construction the last one.
    numbers = re.findall(r"[-+]?\d*\.?\d+", output)
    if numbers:
        return numbers[-1]
    return None
# Whole-string "grouped thousands" numbers such as 10,080 or -1,234.5.
_THOUSANDS_SEPARATED = re.compile(r"[-+]?\d{1,3}(?:,\d{3})+(?:\.\d+)?")


def _normalize_math_answer(answer: Any) -> str:
    r"""Canonicalize an answer for string comparison.

    Removes all whitespace, surrounding ``$`` math delimiters, ``\left`` /
    ``\right`` sizing commands and ``\!`` thin spaces, rewrites
    ``\dfrac`` / ``\tfrac`` as ``\frac``, and lower-cases the result.
    """
    text = re.sub(r"\s+", "", str(answer))
    text = text.strip("$")
    # Sizing/spacing macros affect rendering only, not the value. The
    # lookahead keeps commands like \leftarrow / \rightarrow intact.
    text = re.sub(r"\\(?:left|right)(?![a-zA-Z])", "", text)
    text = text.replace("\\!", "")
    text = re.sub(r"\\[dt]frac(?![a-zA-Z])", r"\\frac", text)
    return text.lower()


def _parse_number(text: str) -> Optional[float]:
    """Parse a normalized answer as a float, or return ``None`` if it is not numeric.

    Commas are dropped only when the whole string is a well-formed
    thousands-grouped number (``10,080``), so a pair like ``1,2`` is never
    misread as ``12``.
    """
    if _THOUSANDS_SEPARATED.fullmatch(text):
        text = text.replace(",", "")
    try:
        return float(text)
    except ValueError:
        return None


def _answers_match(pred: Any, label: Any) -> bool:
    """Return True if ``pred`` and ``label`` denote the same answer.

    Compares normalized strings first, then falls back to a numeric
    comparison with an absolute tolerance of 1e-6 when both sides parse as
    numbers. NaN/inf differences compare as not equal.
    """
    pred_norm = _normalize_math_answer(pred)
    label_norm = _normalize_math_answer(label)
    if pred_norm == label_norm:
        return True
    pred_num = _parse_number(pred_norm)
    label_num = _parse_number(label_norm)
    if pred_num is None or label_num is None:
        return False
    return abs(pred_num - label_num) < 1e-6


@BENCHMARKS.register("math500")
class Math500Benchmarker(Benchmarker):
    """MATH-500 benchmark implementation.

    Loads the ``test`` split of ``HuggingFaceH4/MATH-500``, extracts answers
    from model output with ``extract_math_answer``, and scores them by
    normalized string equality with a numeric-tolerance fallback.
    """

    def __init__(self, num_samples: Optional[int] = None):
        """Create the benchmarker.

        Args:
            num_samples: If given, only the first ``num_samples`` problems are
                loaded; ``None`` loads the whole split.
        """
        # NOTE(review): the meaning of the base class's second argument lives
        # in .base.Benchmarker; None is passed through unchanged.
        super().__init__(num_samples, None)

    def load_data(self) -> Tuple[List[Dict[str, Any]], List[Optional[str]]]:
        """Load and preprocess the MATH-500 dataset.

        Returns:
            A ``(questions, labels)`` pair of equal length. Each question is
            ``{"question": <problem text>}``; each label is the reference
            answer string, or ``None`` if none could be determined.
        """
        dataset = load_dataset("HuggingFaceH4/MATH-500")["test"]
        questions: List[Dict[str, Any]] = []
        labels: List[Optional[str]] = []
        for idx, row in enumerate(dataset):
            if self.num_samples is not None and idx >= self.num_samples:
                break
            questions.append({"question": row["problem"]})

            # Prefer the explicit answer field; a missing *or null* value falls
            # back to extracting from the worked solution (a null answer used
            # to become the literal label "None").
            answer: Optional[str] = None
            if row.get("answer") is not None:
                answer = str(row["answer"]).strip()
            elif row.get("solution") is not None:
                answer = extract_math_answer(row["solution"])
            labels.append(answer)
        return questions, labels

    def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[str]:
        """Extract the answer from model output; ``label`` is not used."""
        return extract_math_answer(output)

    def compute_accuracy(
        self, predictions: List[Any], labels: List[Any]
    ) -> Optional[float]:
        """Compute MATH-500 accuracy by comparing normalized answers.

        Items whose label is ``None`` are excluded from the denominator. A
        ``None`` prediction counts as incorrect. Predictions and labels are
        paired with ``zip``, so any excess entries in the longer list are
        ignored.

        Returns:
            ``None`` if there are no labels or every label is ``None``;
            otherwise the fraction correct among scorable items (``0.0`` if
            no scorable item was paired with a prediction).
        """
        if not labels or all(label is None for label in labels):
            return None

        correct = 0
        valid_count = 0
        for pred, label in zip(predictions, labels):
            if label is None:
                continue
            valid_count += 1
            if pred is not None and _answers_match(pred, label):
                correct += 1
        return correct / valid_count if valid_count > 0 else 0.0

    def create_sgl_function(self):
        """Create the SGL generation function for MATH-500."""
        return create_simple_sgl_function(
            function_name="get_math500_answer",
            answer_key="answer",
            max_tokens=self.get_max_new_tokens(),
        )
|