| import json |
| import re |
| from math_verify import parse, verify |
| from .grader import math_equal_process |
| from .math_equivalent_MATH import is_equiv |
| from .parse_utils_qwen import extract_answer as extract_fn |
|
|
|
|
def extract_true_answer(text, name="gsm8k"):
    """Extract the ground-truth answer from a dataset answer field.

    The dataset is identified by a case-insensitive substring match on
    ``name``.

    Args:
        text: Raw answer field of the dataset record.
        name: Name of the dataset.

    Returns:
        For GSM8K the piece of ``text`` following the first ``"#### "``
        marker, for ASDiv-aug the piece following the first ``"####"``
        marker, and ``text`` unchanged for MATH-500, AIME, StrategyQA,
        date_understanding and CRUXEval.

    Raises:
        IndexError: If a GSM8K / ASDiv-aug ``text`` contains no marker.
        ValueError: If ``name`` matches none of the supported datasets.
    """
    # Lower-case once so every dataset (including cruxeval, which used to be
    # matched case-sensitively) is recognised regardless of capitalisation.
    dataset = name.lower()

    if "gsm8k" in dataset:
        return text.split("#### ")[1]
    if "asdiv-aug" in dataset:
        return text.split("####")[1]

    # These datasets store the label directly, so it is passed through as-is.
    passthrough = ("math-500", "aime", "strategyqa", "date_understanding", "cruxeval")
    if any(key in dataset for key in passthrough):
        return text

    raise ValueError(f"Unknown dataset name: {name}")
|
|
|
|
def judge_answer(input, label, data_name="gsm8k", extract=True, prompt_idx=0):
    """Judge whether a model answer matches the ground truth.

    The dataset is identified by a case-insensitive substring match on
    ``data_name``. The comparison depends on the dataset:

    * gsm8k / asdiv-aug: plain ``==`` between the prediction and ``label``.
    * MATH-500: correct as soon as one of four checks succeeds, tried in
      this order: ``verify(label, parse(prediction))``,
      ``math_equal_process((label, prediction))``, string equality, and
      ``is_equiv``.
    * AIME / CRUXEval: equality of the ``str()`` forms.
    * StrategyQA / date_understanding: equality of the lower-cased,
      stripped ``str()`` forms.

    Args:
        input (str): Model response (or an already extracted answer when
            ``extract`` is False).
        label (str): Ground truth.
        data_name (str): Name of the dataset.
        extract (bool): Whether to run ``extract_answer`` on ``input`` first.
        prompt_idx (int): Index of the solver prompt (different format),
            forwarded to ``extract_answer``.

    Returns:
        bool: True if the answer is judged correct, False otherwise.

    Raises:
        ValueError: If ``data_name`` matches none of the supported datasets.
    """
    dataset = data_name.lower()
    prediction = input

    if "gsm8k" in dataset or "asdiv-aug" in dataset:
        if extract:
            prediction = extract_answer(prediction, data_name="gsm8k", prompt_idx=prompt_idx)
        return prediction == label

    if "math-500" in dataset:
        if extract:
            prediction = extract_answer(prediction, data_name="MATH-500", prompt_idx=prompt_idx)

        # Extraction can legitimately yield None (nothing found). There is
        # nothing to compare then, so do not hand None to the verifiers.
        if prediction is None:
            return False

        if verify(label, parse(prediction)):
            return True
        if math_equal_process((label, prediction)):
            return True
        if str(prediction) == str(label):
            return True
        if is_equiv(str(label), str(prediction)):
            return True
        return False

    # (dataset key, name passed to extract_answer, case/whitespace-insensitive?)
    string_match_datasets = (
        ("aime", "AIME_2024", False),
        ("strategyqa", "strategyqa", True),
        ("date_understanding", "date_understanding", True),
        ("cruxeval", "cruxeval", False),
    )
    for key, extractor_name, normalise in string_match_datasets:
        if key not in dataset:
            continue
        if extract:
            prediction = extract_answer(prediction, data_name=extractor_name, prompt_idx=prompt_idx)
        predicted_text = str(prediction)
        expected_text = str(label)
        if normalise:
            predicted_text = predicted_text.lower().strip()
            expected_text = expected_text.lower().strip()
        return predicted_text == expected_text

    raise ValueError(f"Unknown dataset name: {data_name} for judge answer")
|
|
|
|
def _json_final_answer(text):
    """Read the ``final answer`` field of a JSON-object response.

    Returns the field as a string (``''`` when the key is absent), or None
    when ``text`` is not valid JSON or does not decode to a JSON object, which
    tells the caller to fall back to regex extraction.
    """
    try:
        payload = json.loads(text.strip('` \n'))
    except json.JSONDecodeError:
        return None
    # Valid JSON need not be an object (e.g. a bare number or a list); such
    # payloads have no ``.get`` and are treated like undecodable text.
    if not isinstance(payload, dict):
        return None
    final_answer = payload.get('final answer', '')
    return final_answer if isinstance(final_answer, str) else str(final_answer)


def _extract_json_prompt_answer(text, extractor):
    """Apply ``extractor`` to the answer of a JSON-style (prompt 1) response."""
    final_answer = _json_final_answer(text)
    if final_answer is not None:
        return extractor(final_answer)

    # Not a usable JSON object: look for a "final answer: ..." fragment and,
    # failing that, run the extractor over the whole response.
    pattern = r'(?:final answer|my answer)"?:?\s*(.*?)[}<]'
    match = re.search(pattern, text, flags=re.I | re.M | re.DOTALL)
    return extractor(match.group(1) if match else text)


def _extract_json_prompt_math_answer(text):
    """Return the MATH-500 answer of a JSON-style (prompt 1) response, or None."""
    final_answer = _json_final_answer(text)
    if final_answer is None:
        pattern = r'(?:final answer|my answer)"?:?\s*(.*?)(}<|<\|)'
        match = re.search(pattern, text.replace("\n", ""), flags=re.I | re.M | re.DOTALL)
        if match is None:
            return None
        final_answer = match.group(1)

    # The answer is kept verbatim apart from newlines and quote characters.
    for unwanted in ("\n", "\"", "\'"):
        final_answer = final_answer.replace(unwanted, "")
    return final_answer


def extract_answer(text, data_name="gsm8k", prompt_idx=0, model_name="Qwen2.5-7B-Instruct"):
    """Extract the answer from a model response.

    The dataset is identified by a case-insensitive substring match on
    ``data_name``. Prompts 0 and 2 are treated as free-form responses and are
    handed to the dataset's extractor directly; prompt 1 is expected to be a
    JSON object with a ``"final answer"`` field, with a regex fallback when
    the response is not a usable JSON object.

    Args:
        text: Raw response string from the language model.
        data_name: Name of the dataset.
        prompt_idx: Index of the solver prompt (different format); 0, 1 or 2.
        model_name: Name of the answering model. Only consulted for
            gsm8k / asdiv-aug free-form responses, where
            "qwen2.5-1.5b-instruct" selects a dedicated extractor.

    Returns:
        The extracted answer, or None when the selected extractor finds
        nothing. The type depends on the dataset's extractor.

    Raises:
        ValueError: If ``data_name`` is not supported, or if ``prompt_idx`` is
            not 0, 1 or 2 (raised uniformly for every dataset).
    """
    dataset = data_name.lower()

    if "gsm8k" in dataset or "asdiv-aug" in dataset:
        family = "grade_school"
    elif "math-500" in dataset:
        family = "math"
    elif "aime" in dataset or "cruxeval" in dataset:
        family = "numeric"
    elif "date_understanding" in dataset:
        family = "option"
    elif "strategyqa" in dataset:
        family = "bool"
    else:
        raise ValueError(f"Unknown dataset name: {data_name} for extract answer")

    if prompt_idx not in (0, 1, 2):
        raise ValueError(f"Unknown prompt index: {prompt_idx} for extract answer")
    json_prompt = prompt_idx == 1

    if family == "math":
        if json_prompt:
            return _extract_json_prompt_math_answer(text)
        return extract_fn(text, data_name='math')

    if family == "grade_school":
        extractor = _extract_answer
        # The model-specific extractor only applies to free-form responses.
        if not json_prompt and "qwen2.5-1.5b-instruct" in model_name.lower():
            extractor = _extract_qwen25_1_5B_answer
    elif family == "numeric":
        extractor = _extract_answer
    elif family == "option":
        extractor = _extract_option_answer
    else:
        extractor = _extract_bool_answer

    if json_prompt:
        return _extract_json_prompt_answer(text, extractor)
    return extractor(text)
|
|
|
|
| def _extract_bool_answer(text: str) -> bool | None: |
| last_yes = re.search(r'\bsey\b', text.lower()[::-1]) |
| if last_yes is not None: |
| last_yes = last_yes.start() |
| else: |
| last_yes = len(text) |
| last_no = re.search(r'\bon\b', text.lower()[::-1]) |
| if last_no is not None: |
| last_no = last_no.start() |
| else: |
| last_no = len(text) |
| if last_yes == last_no == len(text): |
| return None |
| return last_yes < last_no |
|
|
| def _extract_option_answer(text: str) -> str | None: |
| def clean_option(opt_str): |
| match = re.search(r'[a-f]', opt_str.lower()[::-1]) |
| return match.group(0).upper() if match else None |
|
|
| |
| |
| boxed_pattern = r"\\boxed\{\s*(.*)\s*\}" |
| all_matches = list(re.finditer(boxed_pattern, text, re.IGNORECASE)) |
| if all_matches: |
| return clean_option(all_matches[-1].group(1)) |
|
|
| |
| answer_pattern = r"he answer is\s*(.*)" |
| all_matches = list(re.finditer(boxed_pattern, text, re.IGNORECASE)) |
| if all_matches: |
| return clean_option(all_matches[-1].group(1)) |
|
|
| |
| answer_pattern = r"final answer is\s*(.*)" |
| all_matches = list(re.finditer(boxed_pattern, text, re.IGNORECASE)) |
| if all_matches: |
| return clean_option(all_matches[-1].group(1)) |
|
|
| return None |
|
|
|
|
| |
| |
| |
|
|
def _last_balanced_boxed(text):
    """Locate the last ``\\boxed{...}`` in ``text`` whose braces balance.

    Returns ``(start, end)`` slice bounds covering the whole expression
    (from the backslash to the closing brace), or None if there is none.
    """
    opener = "\\boxed{"
    start = text.rfind(opener)
    while start != -1:
        depth = 0
        # Begin on the opening brace itself so nested groups are tracked.
        for pos in range(start + len(opener) - 1, len(text)):
            char = text[pos]
            if char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
                if depth == 0:
                    return start, pos + 1
        # This occurrence never closes; try the one before it.
        start = text.rfind(opener, 0, start)
    return None


def extract_MATH_solution(solution_str: str):
    """Extract the final boxed answer from the model's response string.

    Only the assistant part of the string is inspected: everything up to the
    first "Assistant:" marker, or else the first "<|im_start|>assistant"
    marker, is dropped. The last brace-balanced ``\\boxed{...}`` is then
    taken, preferring one inside ``<answer>...</answer>`` tags.

    Args:
        solution_str: Raw response string from the language model.

    Returns:
        When the box is found inside ``<answer>`` tags, the whole
        ``\\boxed{...}`` expression; otherwise only the content between the
        braces of the last box in the response. None if no balanced box
        exists (a message is printed in that case).
    """
    if "Assistant:" in solution_str:
        processed_str = solution_str.split("Assistant:", 1)[1]
    elif "<|im_start|>assistant" in solution_str:
        processed_str = solution_str.split("<|im_start|>assistant", 1)[1]
    else:
        processed_str = solution_str

    # Prefer the latest <answer> section that actually contains a box.
    answer_sections = re.findall(r'<answer>(.*?)</answer>', processed_str, re.DOTALL)
    for section in reversed(answer_sections):
        span = _last_balanced_boxed(section)
        if span is not None:
            return section[span[0]:span[1]].strip()

    span = _last_balanced_boxed(processed_str)
    if span is None:
        print("[Error] No valid answer tags found")
        return None

    # Outside <answer> tags the wrapper is dropped and only the content kept.
    content_start = span[0] + len("\\boxed{")
    return processed_str[content_start:span[1] - 1].strip()
|
|
|
|
| def _extract_answer(text): |
| """ |
| Extract numerical answer from generated text. |
| handling various edge cases. |
| |
| Args: |
| text (str): Generated text to extract answer from. |
| |
| Returns: |
| str or None: Extracted numerical answer, or None if not found. |
| """ |
| if text is None: |
| return None |
| |
| text = text.strip() |
|
|
| def clean_number(num_str): |
| """Remove currency symbols, commas, and whitespace.""" |
| num_str = re.sub(r'[$€£¥]', '', num_str) |
| num_str = re.sub(r',', '', num_str) |
| num_str = re.sub(r'\s', '', num_str) |
| return num_str |
|
|
| |
| |
| boxed_pattern = r"\\boxed\{\s*([$€£¥]?\s*-?\s*[\d,]+(?:\.\d+)?)\s*\}" |
| match = re.search(boxed_pattern, text, re.IGNORECASE) |
| if match: |
| return clean_number(match.group(1)) |
| |
| |
| answer_pattern = r"Answer:\s*([$€£¥]?\s*-?\s*[\d,]+(?:\.\d+)?)" |
| match = re.search(answer_pattern, text, re.IGNORECASE) |
| if match: |
| return clean_number(match.group(1)) |
| |
| |
| equals_pattern = r"=\s*([$€£¥]?\s*-?\s*[\d,]+(?:\.\d+)?)" |
| match = re.search(equals_pattern, text) |
| if match: |
| return clean_number(match.group(1)) |
|
|
| |
| currency_pattern = r"is\s*([$€£¥]?\s*-?\s*[\d,]+(?:\.\d+)?)\s*(?:dollars|euros|pounds|yen)" |
| match = re.search(currency_pattern, text, re.IGNORECASE) |
| if match: |
| return clean_number(match.group(1)) |
|
|
| |
| lines = text.split('\n') |
| for line in reversed(lines): |
| line = line.strip() |
| if line: |
| final_num_pattern = r"([$€£¥]?\s*-?\s*[\d,]+(?:\.\d+)?)\s*$" |
| match = re.search(final_num_pattern, line) |
| if match: |
| return clean_number(match.group(1)) |
|
|
| |
| number_pattern = r"([$€£¥]?\s*-?\s*[\d,]+(?:\.\d+)?)" |
| matches = re.findall(number_pattern, text) |
| if matches: |
| return clean_number(matches[-1]) |
|
|
| return None |
|
|
|
|
| def _extract_qwen25_1_5B_answer(text): |
| """ |
| Extract numerical answer from generated text for Qwen-2.5 1.5B model. |
| handling various edge cases. |
| |
| Args: |
| text (str): Generated text to extract answer from. |
| |
| Returns: |
| str or None: Extracted numerical answer, or None if not found. |
| """ |
| if text is None: |
| return None |
|
|
| text = text.strip() |
|
|
| def clean_number(num_str): |
| """Remove currency symbols, commas, and whitespace.""" |
| num_str = re.sub(r'[$€£¥]', '', num_str) |
| num_str = re.sub(r',', '', num_str) |
| num_str = re.sub(r'\s', '', num_str) |
| return num_str |
|
|
| |
| |
| boxed_pattern = r"\\boxed\{\s*([$€£¥]?\s*-?\s*[\d,]+(?:\.\d+)?)\s*\}" |
| match = re.search(boxed_pattern, text, re.IGNORECASE) |
| if match: |
| return clean_number(match.group(1)) |
|
|
| |
| answer_pattern = r"he answer is\s*([$€£¥]?\s*-?\s*[\d,]+(?:\.\d+)?)" |
| match = re.search(answer_pattern, text, re.IGNORECASE) |
| if match: |
| return clean_number(match.group(1)) |
|
|
| |
| answer_pattern = r"final answer is\s*([$€£¥]?\s*-?\s*[\d,]+(?:\.\d+)?)" |
| match = re.search(answer_pattern, text, re.IGNORECASE) |
| if match: |
| return clean_number(match.group(1)) |
|
|
| |
| number_pattern = r'\d+(?:,\d+)*(?:\.\d+)?' |
| matches = re.findall(number_pattern, text) |
| if matches: |
| return clean_number(matches[-1]) |
|
|
| return None |
|
|