File size: 4,639 Bytes
62dca4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
"""
AIME benchmark
"""

import re
from typing import Any, Dict, List, Optional, Tuple

from datasets import load_dataset

from .base import Benchmarker
from .registry import BENCHMARKS
from .utils import create_simple_sgl_function


def _last_boxed_content(output: str) -> Optional[str]:
    r"""Return the stripped contents of the last non-empty ``\boxed{...}``.

    Braces are matched by depth, so nested groups such as
    ``\boxed{\text{42}}`` are captured whole instead of being cut at the
    first ``}``. Empty boxes (e.g. an echoed ``\boxed{}`` instruction) and
    boxes that are never closed are ignored.

    Args:
        output: Raw model output text.

    Returns:
        The content of the last qualifying box, or ``None`` if there is none.
    """
    last_content = None
    for opener in re.finditer(r"\\boxed\s*\{", output):
        start = opener.end()
        depth = 1
        pos = start
        while pos < len(output) and depth > 0:
            char = output[pos]
            if char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
            pos += 1
        if depth == 0:
            content = output[start : pos - 1].strip()
            if content:
                last_content = content
    return last_content


def extract_aime_answer(output: str) -> Optional[str]:
    r"""Extract the final answer from an AIME problem solution.

    AIME answers are integers between 0 and 999 and are usually reported in
    ``\boxed{}``. The strategies below are tried in order. Within each
    strategy the LAST occurrence wins, because models often box or state
    intermediate results before the final one:

    1. The last non-empty ``\boxed{...}``: its last run of digits, or the raw
       box content if it contains no digits.
    2. The last ``\boxed <digits>`` (no braces).
    3. Phrases such as "The answer is 42", "Answer: 123" or
       "Final answer = 7" (case-insensitive).
    4. A trailing "... is 42" / "... equals 42" / "... = 42" at the very end
       of the text (an optional final period is allowed).
    5. The last standalone integer in the text that lies in 0..999.

    Args:
        output: Raw model output text.

    Returns:
        The extracted answer string (digits are returned as written, so
        leading zeros are preserved), or ``None`` if nothing was found.
    """
    boxed = _last_boxed_content(output)
    if boxed is not None:
        numbers = re.findall(r"\d+", boxed)
        # Take the last number in the box, e.g. "x = 25" -> "25".
        return numbers[-1] if numbers else boxed

    unbraced = re.findall(r"\\boxed\s+(\d+)", output)
    if unbraced:
        return unbraced[-1]

    answer_patterns = (
        # "answer is 42", "Answer: 123", "Final Answer = 7", "answer is: 5".
        r"answer(?:\s+is)?\s*[:=]?\s*(\d+)",
        # Text ending in "is 42", "equals 42" or "= 42". \b keeps "is" from
        # matching inside words such as "this".
        r"(?:\bis|\bequals?|=)\s*(\d+)\s*\.?\s*$",
    )
    for pattern in answer_patterns:
        matches = re.findall(pattern, output, re.IGNORECASE)
        if matches:
            return matches[-1]

    # Fallback: last standalone integer in the valid AIME range (0..999).
    # Compare significant-digit length instead of calling int(): int() raises
    # ValueError on digit runs longer than Python's int-string limit (4300
    # digits), which degenerate repetitive outputs can contain.
    in_range = [
        n for n in re.findall(r"\b(\d+)\b", output) if len(n.lstrip("0")) <= 3
    ]
    return in_range[-1] if in_range else None


@BENCHMARKS.register("aime")
class AIMEBenchmarker(Benchmarker):
    """AIME benchmark implementation.

    Loads problems from the ``Maxwell-Jia/AIME_2024`` dataset, prompts the
    model to reason step by step and box its final answer, and scores the
    extracted answers against the reference answers.
    """

    def __init__(self, num_samples: Optional[int] = None):
        # NOTE(review): the second positional argument of Benchmarker is
        # passed as None here; its meaning is defined in .base — confirm.
        super().__init__(num_samples, None)

    def load_data(self) -> Tuple[List[Dict[str, Any]], List[Optional[str]]]:
        """Load and preprocess the AIME dataset.

        Returns:
            A pair ``(questions, labels)``. Each question is
            ``{"question": <problem text>}``. Each label is the stripped string
            form of the row's ``Answer`` (or ``answer``) field, or ``None`` if
            the row has neither. When ``self.num_samples`` is set, at most that
            many rows are read.
        """
        dataset = load_dataset("Maxwell-Jia/AIME_2024")["train"]
        questions: List[Dict[str, Any]] = []
        labels: List[Optional[str]] = []
        for idx, q in enumerate(dataset):
            if self.num_samples is not None and idx >= self.num_samples:
                break

            questions.append({"question": q["Problem"]})
            # Accept either capitalization of the answer column.
            answer = None
            if "Answer" in q:
                answer = str(q["Answer"]).strip()
            elif "answer" in q:
                answer = str(q["answer"]).strip()
            labels.append(answer)
        return questions, labels

    def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[str]:
        """Extract the final answer from model output.

        ``label`` is accepted for interface compatibility and is not used.
        """
        return extract_aime_answer(output)

    @staticmethod
    def _answers_match(pred: Any, label: Any) -> bool:
        """Return True if *pred* and *label* denote the same AIME answer.

        They match when their stripped string forms are equal, or when both
        parse as the same integer (so ``"073"`` matches ``"73"``).
        """
        pred_text = str(pred).strip()
        label_text = str(label).strip()
        if pred_text == label_text:
            return True
        try:
            return int(pred_text) == int(label_text)
        except ValueError:
            return False

    def compute_accuracy(
        self, predictions: List[Any], labels: List[Any]
    ) -> Optional[float]:
        """Compute accuracy for AIME by comparing numeric answers.

        Predictions and labels are paired with ``zip``, so surplus items on
        either side are ignored. Only pairs with a non-``None`` label count
        toward the denominator. A ``None`` prediction is always wrong.

        Returns:
            ``None`` if ``labels`` is empty or every label is ``None``.
            ``0.0`` if no labelled item could be paired with a prediction.
            Otherwise, the fraction of labelled items answered correctly.
        """
        if not labels or all(label is None for label in labels):
            return None

        correct = 0
        valid_count = 0
        for pred, label in zip(predictions, labels):
            if label is None:
                continue
            valid_count += 1
            if pred is not None and self._answers_match(pred, label):
                correct += 1

        return correct / valid_count if valid_count > 0 else 0.0

    def create_sgl_function(self):
        """Create the SGL generation function with a step-by-step reasoning prompt.

        The prompt asks the model to box its final answer, which is the format
        ``extract_aime_answer`` tries first.
        """
        return create_simple_sgl_function(
            function_name="reasoning_gen",
            answer_key="answer",
            user_prefix="\nPlease reason step by step, and put your final answer within \\boxed{}.",
        )

    def get_max_new_tokens(self) -> int:
        """Return the generation budget; AIME reasoning chains are long."""
        return 32768