# yfan07's picture
# Add files using upload-large-folder tool
# 2fdf3c9 verified
import json
import re
from math_verify import parse, verify
from .grader import math_equal_process
from .math_equivalent_MATH import is_equiv
from .parse_utils_qwen import extract_answer as extract_fn
def extract_true_answer(text, name="gsm8k"):
    '''
    Extract the ground-truth answer from a dataset label field.

    The dataset is identified by a case-insensitive substring match on *name*.

    Args:
        text: raw label text from the dataset
        name: name of the dataset
    Returns:
        answer: for gsm8k, the text after "#### "; for asdiv-aug, the text
        after "####"; for math-500 / aime / strategyqa / date_understanding /
        cruxeval, *text* unchanged.
    Raises:
        IndexError: if a gsm8k/asdiv-aug label lacks its "####" separator.
        ValueError: for an unknown dataset name.
    '''
    lowered = name.lower()
    if "gsm8k" in lowered:
        return text.split("#### ")[1]
    if "asdiv-aug" in lowered:
        return text.split("####")[1]
    # These datasets store the answer verbatim. Every check here is
    # case-insensitive; "cruxeval" previously was not.
    passthrough = ("math-500", "aime", "strategyqa", "date_understanding", "cruxeval")
    if any(key in lowered for key in passthrough):
        return text
    raise ValueError(f"Unknown dataset name: {name}")
def judge_answer(input, label, data_name="gsm8k", extract=True, prompt_idx=0):
    """Score.
    Judge whether the model's answer matches the ground truth.

    gsm8k / asdiv-aug / aime / cruxeval use exact string comparison, and
    strategyqa / date_understanding use a case- and whitespace-insensitive one.
    MATH-500 counts as correct if ANY of these accepts it: HuggingFace
    math_verify, the Qwen2.5-Math grader (``math_equal_process``), exact string
    match, or the MATH ``is_equiv`` normaliser.

    Args:
        input (str): model response (or an already-extracted answer when
            ``extract`` is False)
        label (str): ground truth
        data_name (str): name of the dataset, e.g. "gsm8k", "MATH-500"
        extract (bool): whether to extract answer from model response
        prompt_idx (int): index of the solver prompt (different format)
    Returns:
        bool: True if the answer is correct, False otherwise
    Raises:
        ValueError: for an unknown dataset name.
    """
    name = data_name.lower()
    if "gsm8k" in name or "asdiv-aug" in name:
        if extract:
            input = extract_answer(input, data_name="gsm8k", prompt_idx=prompt_idx)
        return input == label
    if "math-500" in name:
        if extract:
            input = extract_answer(input, data_name="MATH-500", prompt_idx=prompt_idx)
        if input is None:
            # Nothing could be extracted from the response, so it cannot match.
            return False
        # huggingface math_verify: both gold and prediction must be parsed.
        # verify() on a raw gold string only does a string comparison. The
        # \boxed{} wrapper gives math_verify's LaTeX extractor a delimiter to
        # find, since the extracted answers are bare (no $...$).
        hf_gold = parse(f"\\boxed{{{label}}}")
        hf_pred = parse(f"\\boxed{{{input}}}")
        if verify(hf_gold, hf_pred):
            return True
        # qwen2.5-math grader
        if math_equal_process((label, input)):
            return True
        # exact match
        if str(input) == str(label):
            return True
        # MATH-500 string normaliser
        return bool(is_equiv(str(label), str(input)))
    if "aime" in name:
        if extract:
            input = extract_answer(input, data_name="AIME_2024", prompt_idx=prompt_idx)
        return str(input) == str(label)
    if "strategyqa" in name:
        if extract:
            input = extract_answer(input, data_name="strategyqa", prompt_idx=prompt_idx)
        return str(input).lower().strip() == str(label).lower().strip()
    if "date_understanding" in name:
        if extract:
            input = extract_answer(input, data_name="date_understanding", prompt_idx=prompt_idx)
        return str(input).lower().strip() == str(label).lower().strip()
    if "cruxeval" in name:
        if extract:
            input = extract_answer(input, data_name="cruxeval", prompt_idx=prompt_idx)
        return str(input) == str(label)
    raise ValueError(f"Unknown dataset name: {data_name} for judge answer")
def _json_final_answer(text):
    """Return the ``"final answer"`` field of a JSON-formatted response.

    Surrounding backticks, spaces and newlines are stripped before decoding.
    Returns the field as a string: non-string values go through ``str``, and
    a missing key yields ``''``. Returns None when the text is not valid JSON
    or does not decode to a JSON object.
    """
    try:
        answer = json.loads(text.strip('` \n'))
    except json.JSONDecodeError:
        return None
    # A bare JSON scalar or array (e.g. the model replied just "42") has no
    # "final answer" key. Calling .get on it would raise AttributeError, so
    # treat it like unparseable JSON.
    if not isinstance(answer, dict):
        return None
    final_answer = answer.get('final answer', '')
    if not isinstance(final_answer, str):
        final_answer = str(final_answer)
    return final_answer


def _extract_json_prompt_answer(text, cleaner):
    """Extract an answer from a JSON-prompt (prompt_idx == 1) response.

    Tries strict JSON first. Otherwise it uses a regex that grabs the text
    after "final answer" / "my answer" up to the next ``}`` or ``<``, and as a
    last resort the whole response. The candidate is post-processed by *cleaner*.
    """
    final_answer = _json_final_answer(text)
    if final_answer is not None:
        return cleaner(final_answer)
    pattern = r'(?:final answer|my answer)"?:?\s*(.*?)[}<]'
    match = re.search(pattern, text, flags=re.I | re.M | re.DOTALL)
    if match:
        return cleaner(match.group(1))
    return cleaner(text)


def _strip_newlines_and_quotes(s):
    """Remove newlines and single/double quote characters from *s*."""
    return s.replace("\n", "").replace("\"", "").replace("\'", "")


def _extract_math_json_answer(text):
    """MATH-500 variant of the JSON-prompt extraction.

    Returns the answer text with newlines and quotes removed and no numeric
    cleaning (MATH answers may be LaTeX). Returns None when neither JSON nor
    the regex fallback finds one.
    """
    final_answer = _json_final_answer(text)
    if final_answer is not None:
        return _strip_newlines_and_quotes(final_answer)
    text = text.replace("\n", "")
    # The capture ends at "}<" or "<|" rather than at any "}", because LaTeX
    # answers contain braces. "<|" presumably targets chat-template tokens such
    # as "<|im_end|>" (TODO confirm).
    pattern = r'(?:final answer|my answer)"?:?\s*(.*?)(}<|<\|)'
    match = re.search(pattern, text, flags=re.I | re.M | re.DOTALL)
    if match:
        return _strip_newlines_and_quotes(match.group(1))
    return None


def extract_answer(text, data_name="gsm8k", prompt_idx=0, model_name="Qwen2.5-7B-Instruct"):
    '''
    Extract the answer from a model response.

    Args:
        text: Raw response string from the language model
        data_name: name of the dataset. It is matched case-insensitively by
            substring against gsm8k / asdiv-aug, math-500, aime / cruxeval,
            date_understanding and strategyqa (checked in that order).
        prompt_idx: index of the solver prompt. 0 and 2 use the boxed format;
            1 uses the JSON {"final answer": ...} format.
        model_name: only consulted for gsm8k/asdiv-aug boxed prompts, where
            Qwen2.5-1.5B-Instruct gets a dedicated extractor.
    Returns:
        answer: a numeric string (gsm8k, asdiv-aug, aime, cruxeval), the
        answer text (MATH-500), an option letter (date_understanding) or a
        bool (strategyqa). May be None when nothing can be extracted.
    Raises:
        ValueError: for an unknown dataset name or prompt index.
    '''
    name = data_name.lower()
    if "gsm8k" in name or "asdiv-aug" in name:
        if prompt_idx in (0, 2):
            # boxed format
            if "qwen2.5-1.5b-instruct" in model_name.lower():
                # This model needs its own heuristics.
                return _extract_qwen25_1_5B_answer(text)
            return _extract_answer(text)
        if prompt_idx == 1:
            return _extract_json_prompt_answer(text, _extract_answer)
    elif "math-500" in name:
        if prompt_idx in (0, 2):
            return extract_fn(text, data_name='math')
        if prompt_idx == 1:
            return _extract_math_json_answer(text)
    elif "aime" in name or "cruxeval" in name:
        if prompt_idx in (0, 2):
            return _extract_answer(text)
        if prompt_idx == 1:
            return _extract_json_prompt_answer(text, _extract_answer)
    elif "date_understanding" in name:
        if prompt_idx in (0, 2):
            return _extract_option_answer(text)
        if prompt_idx == 1:
            return _extract_json_prompt_answer(text, _extract_option_answer)
    elif "strategyqa" in name:
        if prompt_idx in (0, 2):
            return _extract_bool_answer(text)
        if prompt_idx == 1:
            return _extract_json_prompt_answer(text, _extract_bool_answer)
    else:
        raise ValueError(f"Unknown dataset name: {data_name} for extract answer")
    # Reached only for a known dataset with an unsupported prompt index. This
    # now applies to every dataset (math-500 and date_understanding used to
    # return None silently).
    raise ValueError(f"Unknown prompt index: {prompt_idx} for extract answer")
def _extract_bool_answer(text: str) -> bool | None:
last_yes = re.search(r'\bsey\b', text.lower()[::-1])
if last_yes is not None:
last_yes = last_yes.start()
else:
last_yes = len(text)
last_no = re.search(r'\bon\b', text.lower()[::-1])
if last_no is not None:
last_no = last_no.start()
else:
last_no = len(text)
if last_yes == last_no == len(text):
return None
return last_yes < last_no
def _extract_option_answer(text: str) -> str | None:
def clean_option(opt_str):
match = re.search(r'[a-f]', opt_str.lower()[::-1])
return match.group(0).upper() if match else None
### Several Corner Cases ###
# 1. \boxed{}
boxed_pattern = r"\\boxed\{\s*(.*)\s*\}"
all_matches = list(re.finditer(boxed_pattern, text, re.IGNORECASE))
if all_matches:
return clean_option(all_matches[-1].group(1))
# 2. he answer is
answer_pattern = r"he answer is\s*(.*)"
all_matches = list(re.finditer(boxed_pattern, text, re.IGNORECASE))
if all_matches:
return clean_option(all_matches[-1].group(1))
# 3. final answer is
answer_pattern = r"final answer is\s*(.*)"
all_matches = list(re.finditer(boxed_pattern, text, re.IGNORECASE))
if all_matches:
return clean_option(all_matches[-1].group(1))
return None
######################
# MATH #
######################
def _last_boxed_span(s):
    """Return ``(start, end)`` of the last brace-balanced ``\\boxed{...}`` in *s*.

    ``s[start:end]`` is the whole ``\\boxed{...}`` expression. An occurrence
    whose braces never close (e.g. a truncated response) is skipped in favour
    of an earlier complete one. Returns None if there is no complete occurrence.
    """
    marker = "\\boxed{"
    start = s.rfind(marker)
    while start != -1:
        depth = 0
        # Begin at the opening "{" of the marker.
        for i in range(start + len(marker) - 1, len(s)):
            if s[i] == "{":
                depth += 1
            elif s[i] == "}":
                depth -= 1
                if depth == 0:
                    return start, i + 1
        start = s.rfind(marker, 0, start)
    return None


def extract_MATH_solution(solution_str: str):
    """Extracts the final answer from the model's response string.

    Only the text after the first "Assistant:" (or "<|im_start|>assistant")
    marker is searched. If an <answer>...</answer> block contains a
    ``\\boxed{}``, the whole last such ``\\boxed{...}`` expression (wrapper
    included) is returned. Otherwise the *contents* of the last ``\\boxed{...}``
    in the response are returned. Braces are matched, so nested LaTeX such as
    ``\\boxed{\\frac{1}{2}}`` is handled and text after the box is not captured.

    Args:
        solution_str: Raw response string from the language model
    Returns:
        extracted final answer (stripped), or None (after printing an error)
        when no ``\\boxed{...}`` is found.
    """
    # Split response to isolate assistant output
    if "Assistant:" in solution_str:
        processed_str = solution_str.split("Assistant:", 1)[1]
    elif "<|im_start|>assistant" in solution_str:
        processed_str = solution_str.split("<|im_start|>assistant", 1)[1]
    else:
        processed_str = solution_str

    # Prefer a \boxed{} inside XML-style <answer> tags, latest block first.
    answer_blocks = re.findall(r'<answer>(.*?)</answer>', processed_str, re.DOTALL)
    for block in reversed(answer_blocks):
        span = _last_boxed_span(block)
        if span is not None:
            start, end = span
            return block[start:end].strip()

    # Otherwise use the last balanced \boxed{...} anywhere. The old greedy
    # regex captured everything up to the final "}" in the response.
    span = _last_boxed_span(processed_str)
    if span is None:
        print("[Error] No valid answer tags found")
        return None
    start, end = span
    return processed_str[start + len("\\boxed{"):end - 1].strip()
def _extract_answer(text):
"""
Extract numerical answer from generated text.
handling various edge cases.
Args:
text (str): Generated text to extract answer from.
Returns:
str or None: Extracted numerical answer, or None if not found.
"""
if text is None:
return None
text = text.strip()
def clean_number(num_str):
"""Remove currency symbols, commas, and whitespace."""
num_str = re.sub(r'[$€£¥]', '', num_str)
num_str = re.sub(r',', '', num_str)
num_str = re.sub(r'\s', '', num_str)
return num_str
### Several Corner Cases ###
# 1. \boxed{}
boxed_pattern = r"\\boxed\{\s*([$€£¥]?\s*-?\s*[\d,]+(?:\.\d+)?)\s*\}"
match = re.search(boxed_pattern, text, re.IGNORECASE)
if match:
return clean_number(match.group(1))
# 2. Answer:
answer_pattern = r"Answer:\s*([$€£¥]?\s*-?\s*[\d,]+(?:\.\d+)?)"
match = re.search(answer_pattern, text, re.IGNORECASE)
if match:
return clean_number(match.group(1))
# 3. =
equals_pattern = r"=\s*([$€£¥]?\s*-?\s*[\d,]+(?:\.\d+)?)"
match = re.search(equals_pattern, text)
if match:
return clean_number(match.group(1))
# 4. With currency unit
currency_pattern = r"is\s*([$€£¥]?\s*-?\s*[\d,]+(?:\.\d+)?)\s*(?:dollars|euros|pounds|yen)"
match = re.search(currency_pattern, text, re.IGNORECASE)
if match:
return clean_number(match.group(1))
# 5. Search from the last line of the text upwards, matching independent numbers
lines = text.split('\n')
for line in reversed(lines):
line = line.strip()
if line:
final_num_pattern = r"([$€£¥]?\s*-?\s*[\d,]+(?:\.\d+)?)\s*$"
match = re.search(final_num_pattern, line)
if match:
return clean_number(match.group(1))
# 6. Returns the last matching number in the text
number_pattern = r"([$€£¥]?\s*-?\s*[\d,]+(?:\.\d+)?)"
matches = re.findall(number_pattern, text)
if matches:
return clean_number(matches[-1])
return None
def _extract_qwen25_1_5B_answer(text):
"""
Extract numerical answer from generated text for Qwen-2.5 1.5B model.
handling various edge cases.
Args:
text (str): Generated text to extract answer from.
Returns:
str or None: Extracted numerical answer, or None if not found.
"""
if text is None:
return None
text = text.strip()
def clean_number(num_str):
"""Remove currency symbols, commas, and whitespace."""
num_str = re.sub(r'[$€£¥]', '', num_str)
num_str = re.sub(r',', '', num_str)
num_str = re.sub(r'\s', '', num_str)
return num_str
### Several Corner Cases ###
# 1. \boxed{}
boxed_pattern = r"\\boxed\{\s*([$€£¥]?\s*-?\s*[\d,]+(?:\.\d+)?)\s*\}"
match = re.search(boxed_pattern, text, re.IGNORECASE)
if match:
return clean_number(match.group(1))
# 2. he answer is
answer_pattern = r"he answer is\s*([$€£¥]?\s*-?\s*[\d,]+(?:\.\d+)?)"
match = re.search(answer_pattern, text, re.IGNORECASE)
if match:
return clean_number(match.group(1))
# 3. final answer is
answer_pattern = r"final answer is\s*([$€£¥]?\s*-?\s*[\d,]+(?:\.\d+)?)"
match = re.search(answer_pattern, text, re.IGNORECASE)
if match:
return clean_number(match.group(1))
# 4. Returns the last matching number in the text
number_pattern = r'\d+(?:,\d+)*(?:\.\d+)?'
matches = re.findall(number_pattern, text)
if matches:
return clean_number(matches[-1])
return None