"""
AIME benchmark
"""
import re
from typing import Any, Dict, List, Optional, Tuple
from datasets import load_dataset
from .base import Benchmarker
from .registry import BENCHMARKS
from .utils import create_simple_sgl_function
# ``\boxed{...}`` groups; the capture stops at the first ``}``, so nested
# braces are truncated (harmless here because only the digits are kept).
_BOXED_BRACED_RE = re.compile(r"\\boxed\{([^}]+)\}")
# ``\boxed 42`` written without braces.
_BOXED_BARE_RE = re.compile(r"\\boxed\s+(\d+)")
# Natural-language answer phrases, tried in order.  A separate
# "final answer" pattern is unnecessary: the first pattern already matches
# the "answer ..." tail of "final answer ...".
_ANSWER_PHRASE_RES = (
    # "Answer: 123", "answer 7", "The answer is 42", "final answer is: 5"
    re.compile(r"\banswer(?:\s+is)?[\s:]+(\d+)", re.IGNORECASE),
    # "... is 42", "... equals 42", "... = 42" at the very end of the text
    re.compile(r"(?:\b(?:is|equals?)\s+|=\s*)(\d+)\s*$", re.IGNORECASE),
)
_INTEGER_RE = re.compile(r"\b(\d+)\b")


def _is_aime_value(digits: str) -> bool:
    """Return True when *digits* denotes an integer in the AIME range 0-999."""
    try:
        return 0 <= int(digits) <= 999
    except ValueError:
        # Recent Python versions refuse to convert extremely long digit runs
        # (default limit: 4300 digits).  Such a run cannot be an AIME answer.
        return False


def extract_aime_answer(output: str) -> Optional[str]:
    r"""Extract the final answer from an AIME problem solution.

    AIME answers are typically integers between 0 and 999 and are usually
    given in ``\boxed{}`` format.  Strategies are tried in this order, and
    each one prefers the LAST occurrence in the text, since the final answer
    normally comes at the end of a step-by-step solution:

    1. ``\boxed{...}``: the last integer inside the last boxed group, or the
       stripped group content when it contains no digits.
    2. ``\boxed 42`` (no braces).
    3. Phrases such as "The answer is 42", "Answer: 123", or a trailing
       "is 42" / "equals 42" / "= 42".
    4. The last standalone integer in the text that lies in 0-999.

    Args:
        output: Raw text produced by the model.

    Returns:
        The extracted answer as a string (leading zeros are preserved), or
        ``None`` when no candidate is found.
    """
    boxed_groups = _BOXED_BRACED_RE.findall(output)
    if boxed_groups:
        content = boxed_groups[-1].strip()
        digits = re.findall(r"\d+", content)
        # Take the last number in the box (usually the final answer).
        return digits[-1] if digits else content

    bare_boxed = _BOXED_BARE_RE.findall(output)
    if bare_boxed:
        return bare_boxed[-1]

    for pattern in _ANSWER_PHRASE_RES:
        found = pattern.findall(output)
        if found:
            return found[-1]

    # Fallback: the last integer in the text within the AIME answer range.
    in_range = [n for n in _INTEGER_RE.findall(output) if _is_aime_value(n)]
    return in_range[-1] if in_range else None
@BENCHMARKS.register("aime")
class AIMEBenchmarker(Benchmarker):
    """Benchmarker for AIME problems, registered under the name ``"aime"``."""

    def __init__(self, num_samples: Optional[int] = None):
        """Initialise the benchmarker.

        Args:
            num_samples: Maximum number of problems to load; ``None`` loads
                every problem in the split.
        """
        # The second positional argument of the base initialiser is always
        # None here.  NOTE(review): its meaning is defined by Benchmarker,
        # which is not visible in this file.
        super().__init__(num_samples, None)

    def load_data(self) -> Tuple[List[Dict[str, Any]], List[Optional[str]]]:
        """Load the AIME 2024 problems and their reference answers.

        Reads the ``train`` split of ``Maxwell-Jia/AIME_2024`` and stops
        after ``self.num_samples`` rows when that limit is set.

        Returns:
            ``(questions, labels)`` where each question is a dict of the form
            ``{"question": <Problem text>}`` and each label is the stripped
            string form of the row's ``Answer`` (or ``answer``) field, or
            ``None`` when the row has neither field.
        """
        rows = load_dataset("Maxwell-Jia/AIME_2024")["train"]

        questions = []
        labels = []
        for position, row in enumerate(rows):
            reached_limit = (
                self.num_samples is not None and position >= self.num_samples
            )
            if reached_limit:
                break

            questions.append({"question": row["Problem"]})

            # Prefer the capitalised field name, then the lowercase one.
            for field in ("Answer", "answer"):
                if field in row:
                    reference = str(row[field]).strip()
                    break
            else:
                reference = None
            labels.append(reference)

        return questions, labels

    def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[str]:
        """Return the answer found in *output*; *label* is not used."""
        return extract_aime_answer(output)

    def compute_accuracy(
        self, predictions: List[Any], labels: List[Any]
    ) -> Optional[float]:
        """Compute the fraction of labelled problems answered correctly.

        A prediction is correct when its stripped string form equals the
        label's, or when both parse as equal integers (so ``"042"`` matches
        ``"42"``).  Pairs whose label is ``None`` are ignored; a ``None``
        prediction for a labelled problem counts as wrong.

        Returns:
            ``None`` when *labels* is empty or contains only ``None``;
            otherwise the accuracy over the labelled pairs (``0.0`` when no
            labelled pair was compared).
        """
        if not labels:
            return None
        if all(item is None for item in labels):
            return None

        hits = 0
        scored = 0
        for guess, truth in zip(predictions, labels):
            if truth is None:
                continue
            scored += 1
            if guess is None:
                continue

            guess_text = str(guess).strip()
            truth_text = str(truth).strip()

            # An exact textual match is enough.
            if guess_text == truth_text:
                hits += 1
                continue

            # Otherwise compare as integers; non-integer text is a miss.
            try:
                same_value = int(guess_text) == int(truth_text)
            except ValueError:
                continue
            if same_value:
                hits += 1

        if scored > 0:
            return hits / scored
        return 0.0

    def create_sgl_function(self):
        """Build the SGL function for AIME using a step-by-step reasoning prompt."""
        reasoning_prompt = (
            "\nPlease reason step by step, and put your final answer within \\boxed{}."
        )
        return create_simple_sgl_function(
            function_name="reasoning_gen",
            answer_key="answer",
            user_prefix=reasoning_prompt,
        )

    def get_max_new_tokens(self) -> int:
        """Return the generation budget; AIME problems require more tokens."""
        return 32768
|