File size: 4,329 Bytes
62dca4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
"""
MATH-500 benchmark evaluation script.
"""

import re
from itertools import zip_longest
from typing import Any, Dict, List, Optional, Tuple

from datasets import load_dataset

from .base import Benchmarker
from .registry import BENCHMARKS
from .utils import create_simple_sgl_function


def _last_boxed_content(text: str) -> Optional[str]:
    r"""Return the contents of the last non-empty, brace-balanced ``\boxed{...}``.

    Nested braces (e.g. ``\boxed{\frac{1}{2}}``) are matched properly, and
    backslash-escaped characters such as ``\{`` / ``\}`` are skipped so they do
    not affect the brace depth. Occurrences that are empty or never close
    (e.g. a truncated generation) are ignored in favor of earlier ones.
    Returns ``None`` if no usable occurrence exists.
    """
    openings = list(re.finditer(r"\\boxed\s*\{", text))
    for opening in reversed(openings):
        start = opening.end()
        depth = 1
        i = start
        while i < len(text):
            ch = text[i]
            if ch == "\\":
                # Skip the escaped character so "\{" / "\}" don't change depth.
                i += 2
                continue
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0:
                    break
            i += 1
        if depth == 0:
            content = text[start:i].strip()
            if content:
                return content
    return None


def extract_math_answer(output: str) -> Optional[str]:
    r"""Extract the final answer from a math problem solution.

    Strategies, tried in order:

    1. The contents of the last ``\boxed{...}`` (nested braces supported).
    2. The token following a brace-less ``\boxed`` (e.g. ``\boxed 5``).
    3. The last number after "answer", "answer is" or "answer:", or else a
       number ending the text after "is", "equals" or "=" (an optional
       trailing period is allowed).
    4. The last number anywhere in the text.

    Args:
        output: Model completion or reference solution text.

    Returns:
        The extracted answer string (stripped), or ``None`` if nothing matched.
    """
    boxed = _last_boxed_content(output)
    if boxed is not None:
        return boxed

    # \boxed without braces, e.g. "\boxed 5".
    match = re.search(r"\\boxed\s+([^\s]+)", output)
    if match:
        return match.group(1).strip()

    number = r"([-+]?\d*\.?\d+)"
    answer_patterns = [
        # "Answer: 42", "The answer is 42", "answer 42"
        rf"answer\s*(?:is|:)?\s*{number}",
        # "... is 42", "... equals 42", "... = 42" at the end of the text.
        rf"(?:\bis|\bequals?|=)\s*{number}\s*\.?\s*$",
    ]
    for pattern in answer_patterns:
        matches = re.findall(pattern, output, re.IGNORECASE)
        if matches:
            return matches[-1].strip()

    # Fallback: the last number in the text.
    numbers = re.findall(r"[-+]?\d*\.?\d+", output)
    if numbers:
        return numbers[-1]

    return None


@BENCHMARKS.register("math500")
class Math500Benchmarker(Benchmarker):
    """MATH-500 benchmark implementation.

    Loads the ``test`` split of ``HuggingFaceH4/MATH-500``, extracts a final
    answer from each model output with ``extract_math_answer``, and scores it
    against the reference answer. Scoring uses a normalized string match and
    falls back to a numeric comparison.
    """

    # Absolute tolerance used when both answers parse as numbers.
    _NUMERIC_TOLERANCE = 1e-6

    def __init__(self, num_samples: Optional[int] = None):
        """Create the benchmarker.

        Args:
            num_samples: If set, ``load_data`` keeps only the first
                ``num_samples`` problems; ``None`` loads the whole split.
        """
        # NOTE(review): the meaning of the second positional argument is
        # defined by ``Benchmarker.__init__`` (not visible here).
        super().__init__(num_samples, None)

    def load_data(self) -> Tuple[List[Dict[str, Any]], List[Optional[str]]]:
        """Load and preprocess the MATH-500 test split.

        Returns:
            A ``(questions, labels)`` pair aligned by index. Each question is
            ``{"question": <problem text>}``. Each label is the row's
            ``answer`` field (stringified and stripped). When that field is
            absent, the label is an answer extracted from ``solution``, or
            ``None``.
        """
        dataset = load_dataset("HuggingFaceH4/MATH-500")["test"]
        questions = []
        labels = []
        for idx, q in enumerate(dataset):
            if self.num_samples is not None and idx >= self.num_samples:
                break

            questions.append({"question": q["problem"]})
            # Prefer the explicit answer field; fall back to parsing the solution.
            answer = None
            if "answer" in q:
                answer = str(q["answer"]).strip()
            elif "solution" in q:
                answer = extract_math_answer(q["solution"])
            labels.append(answer)
        return questions, labels

    def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[str]:
        """Extract the final answer from a model completion.

        ``label`` is not used here (presumably it is part of the
        ``Benchmarker`` interface).
        """
        return extract_math_answer(output)

    @staticmethod
    def _normalize_answer(answer: Any) -> str:
        r"""Canonicalize an answer so cosmetic LaTeX differences compare equal.

        The following steps are applied in order:

        - Strip surrounding ``$`` delimiters.
        - Remove ``\left``/``\right`` sizing commands and the thin-space
          commands ``\,``, ``\;``, ``\:`` and ``\!``.
        - Map ``\dfrac``/``\tfrac`` to ``\frac``.
        - Remove all whitespace.
        - Drop trailing periods.
        - Lowercase the result.
        """
        text = str(answer).strip().strip("$")
        text = re.sub(r"\\(?:left|right)(?![a-zA-Z])", "", text)
        text = re.sub(r"\\[dt]frac(?![a-zA-Z])", r"\\frac", text)
        text = re.sub(r"\\[,;:!]", "", text)
        text = re.sub(r"\s+", "", text)
        text = text.rstrip(".")
        return text.lower()

    @staticmethod
    def _parse_number(text: str) -> Optional[float]:
        """Parse a normalized answer as a float, or return ``None``.

        Thousands separators are removed only when the whole string matches
        strict 3-digit grouping (e.g. ``1,000``). This keeps tuples like
        ``1,2`` from being read as ``12``.
        """
        if re.fullmatch(r"[-+]?\d{1,3}(?:,\d{3})+(?:\.\d+)?", text):
            text = text.replace(",", "")
        try:
            return float(text)
        except ValueError:
            return None

    @classmethod
    def _answers_match(cls, pred: Any, label: Any) -> bool:
        """Return True if ``pred`` equals ``label`` after normalization.

        Two answers also match if both parse as numbers within
        ``_NUMERIC_TOLERANCE`` of each other.
        """
        pred_normalized = cls._normalize_answer(pred)
        label_normalized = cls._normalize_answer(label)
        if pred_normalized == label_normalized:
            return True
        pred_num = cls._parse_number(pred_normalized)
        label_num = cls._parse_number(label_normalized)
        if pred_num is None or label_num is None:
            return False
        return abs(pred_num - label_num) < cls._NUMERIC_TOLERANCE

    def compute_accuracy(
        self, predictions: List[Any], labels: List[Any]
    ) -> Optional[float]:
        """Compute the fraction of labeled problems answered correctly.

        Args:
            predictions: Extracted answers, aligned by index with ``labels``.
                ``None`` means no answer could be extracted and counts as
                wrong.
            labels: Reference answers. ``None`` entries are excluded from
                scoring.

        Returns:
            Accuracy in ``[0, 1]``, or ``None`` if ``labels`` is empty or every
            label is ``None``.
        """
        if not labels or all(label is None for label in labels):
            return None

        correct = 0
        valid_count = 0
        # zip_longest: a short ``predictions`` list counts the missing items
        # as wrong instead of silently shrinking the denominator; extra
        # predictions pair with a ``None`` label and are skipped.
        for pred, label in zip_longest(predictions, labels):
            if label is None:
                continue
            valid_count += 1
            if pred is not None and self._answers_match(pred, label):
                correct += 1

        # valid_count > 0 here: at least one label is not None.
        return correct / valid_count

    def create_sgl_function(self):
        """Create the SGL function used to query the model for MATH-500.

        The function is built by ``create_simple_sgl_function`` with answer key
        ``"answer"``. Generation is capped at ``self.get_max_new_tokens()``.
        """
        return create_simple_sgl_function(
            function_name="get_math500_answer",
            answer_key="answer",
            max_tokens=self.get_max_new_tokens(),
        )