import json import re from math_verify import parse, verify from .grader import math_equal_process from .math_equivalent_MATH import is_equiv from .parse_utils_qwen import extract_answer as extract_fn def extract_true_answer(text, name="gsm8k"): ''' Extract answer from text Args: text: input text name: name of the dataset Returns: answer: extracted answer ''' if "gsm8k" in name.lower(): label = text.split("#### ")[1] return label elif "asdiv-aug" in name.lower(): label = text.split("####")[1] return label elif "math-500" in name.lower(): return text elif "aime" in name.lower(): return text elif "strategyqa" in name.lower(): return text elif "date_understanding" in name.lower(): return text elif "cruxeval" in name: return text else: raise ValueError(f"Unknown dataset name: {name}") def judge_answer(input, label, data_name="gsm8k", extract=True, prompt_idx=0): """Score. Judge whether the answer is correct or not. Only exact match is considered correct. Args: input (str): model response label (str): ground truth data_name (str): name of the dataset, ["gsm8k", "MATH-500"] extract (bool): whether to extract answer from model response prompt_idx (int): index of the solver prompt (different format) Returns: bool: True if the answer is correct, False otherwise """ if "gsm8k" in data_name.lower() or "asdiv-aug" in data_name.lower(): if extract: input = extract_answer(input, data_name="gsm8k", prompt_idx=prompt_idx) return (input == label) elif "math-500" in data_name.lower(): if extract: input = extract_answer(input, data_name="MATH-500", prompt_idx=prompt_idx) # huggingface math_verify hf_input = parse(input) hf_verifier_judge = verify(label, hf_input) if hf_verifier_judge: return True # qwen2.5-math qwen_verifier_judge = math_equal_process((label, input)) if qwen_verifier_judge: return True # exact match exact_judge = (str(input) == str(label)) if exact_judge: return True # MATH-500 MATH_500_judge = is_equiv(str(label), str(input)) if MATH_500_judge: return True return False elif "aime" in data_name.lower(): if extract: input = extract_answer(input, data_name="AIME_2024", prompt_idx=prompt_idx) input = str(input) label = str(label) return (input == label) elif "strategyqa" in data_name.lower(): if extract: input = extract_answer(input, data_name="strategyqa", prompt_idx=prompt_idx) input = str(input).lower().strip() label = str(label).lower().strip() return (input == label) elif "date_understanding" in data_name.lower(): if extract: input = extract_answer(input, data_name="date_understanding", prompt_idx=prompt_idx) input = str(input).lower().strip() label = str(label).lower().strip() return (input == label) elif "cruxeval" in data_name.lower(): if extract: input = extract_answer(input, data_name="cruxeval", prompt_idx=prompt_idx) input = str(input) label = str(label) return (input == label) else: raise ValueError(f"Unknown dataset name: {data_name} for judge answer") def extract_answer(text, data_name="gsm8k", prompt_idx=0, model_name="Qwen2.5-7B-Instruct"): ''' Extract answer from model response Args: text: Raw response string from the language model data_name: name of the dataset, ["gsm8k", "MATH-500"] prompt_idx: index of the solver prompt (different format) Returns: answer: extracted answer(pure numbers) ''' if "gsm8k" in data_name.lower() or "asdiv-aug" in data_name.lower(): if prompt_idx == 0 or prompt_idx == 2: # 0: boxed if "qwen2.5-1.5b-instruct" in model_name.lower(): # well, well, well temp = _extract_qwen25_1_5B_answer(text) else: temp = _extract_answer(text) return temp elif prompt_idx == 1: # 1: json try: answer = json.loads(text.strip('` \n')) final_answer = answer.get('final answer', '') if not isinstance(final_answer, str): final_answer = str(final_answer) temp = _extract_answer(final_answer) return temp except json.JSONDecodeError: pattern = r'(?:final answer|my answer)"?:?\s*(.*?)[}<]' match = re.search(pattern, text, flags=re.I | re.M | re.DOTALL) if match: temp = _extract_answer(match.group(1)) return temp else: temp = _extract_answer(text) return temp else: raise ValueError(f"Unknown prompt index: {prompt_idx} for extract answer") elif "math-500" in data_name.lower(): if prompt_idx == 0 or prompt_idx == 2: # 0: boxed temp = extract_fn(text, data_name='math') return temp elif prompt_idx == 1: # json try: answer = json.loads(text.strip('` \n')) final_answer = answer.get('final answer', '') if not isinstance(final_answer, str): final_answer = str(final_answer) final_answer = final_answer.replace("\n", "") final_answer = final_answer.replace("\"", "") final_answer = final_answer.replace("\'", "") return final_answer except json.JSONDecodeError: text = text.replace("\n", "") pattern = r'(?:final answer|my answer)"?:?\s*(.*?)(}<|<\|)' match = re.search(pattern, text, flags=re.I | re.M | re.DOTALL) if match: temp = match.group(1) temp = temp.replace("\n", "") temp = temp.replace("\"", "") temp = temp.replace("\'", "") return temp else: return None elif "aime" in data_name.lower() or "cruxeval" in data_name.lower(): if prompt_idx == 0 or prompt_idx == 2: # 0: boxed temp = _extract_answer(text) return temp elif prompt_idx == 1: # 1: json, {"final answer": ...} try: answer = json.loads(text.strip('` \n')) final_answer = answer.get('final answer', '') if not isinstance(final_answer, str): final_answer = str(final_answer) temp = _extract_answer(final_answer) return temp except json.JSONDecodeError: pattern = r'(?:final answer|my answer)"?:?\s*(.*?)[}<]' match = re.search(pattern, text, flags=re.I | re.M | re.DOTALL) if match: temp = _extract_answer(match.group(1)) return temp else: temp = _extract_answer(text) return temp else: raise ValueError(f"Unknown prompt index: {prompt_idx} for extract answer") elif "date_understanding" in data_name.lower(): if prompt_idx == 0 or prompt_idx == 2: # 0: boxed temp = _extract_option_answer(text) return temp elif prompt_idx == 1: # 1: json, {"final answer": ...} try: answer = json.loads(text.strip('` \n')) final_answer = answer.get('final answer', '') if not isinstance(final_answer, str): final_answer = str(final_answer) temp = _extract_option_answer(final_answer) return temp except json.JSONDecodeError: pattern = r'(?:final answer|my answer)"?:?\s*(.*?)[}<]' match = re.search(pattern, text, flags=re.I | re.M | re.DOTALL) if match: temp = _extract_option_answer(match.group(1)) return temp else: temp = _extract_option_answer(text) return temp elif "strategyqa" in data_name.lower(): if prompt_idx == 0 or prompt_idx == 2: # 0: boxed temp = _extract_bool_answer(text) return temp elif prompt_idx == 1: # 1: json, {"final answer": ...} try: answer = json.loads(text.strip('` \n')) final_answer = answer.get('final answer', '') if not isinstance(final_answer, str): final_answer = str(final_answer) temp = _extract_bool_answer(final_answer) return temp except json.JSONDecodeError: pattern = r'(?:final answer|my answer)"?:?\s*(.*?)[}<]' match = re.search(pattern, text, flags=re.I | re.M | re.DOTALL) if match: temp = _extract_bool_answer(match.group(1)) return temp else: temp = _extract_bool_answer(text) return temp else: raise ValueError(f"Unknown prompt index: {prompt_idx} for extract answer") else: raise ValueError(f"Unknown dataset name: {data_name} for extract answer") def _extract_bool_answer(text: str) -> bool | None: last_yes = re.search(r'\bsey\b', text.lower()[::-1]) if last_yes is not None: last_yes = last_yes.start() else: last_yes = len(text) last_no = re.search(r'\bon\b', text.lower()[::-1]) if last_no is not None: last_no = last_no.start() else: last_no = len(text) if last_yes == last_no == len(text): return None return last_yes < last_no def _extract_option_answer(text: str) -> str | None: def clean_option(opt_str): match = re.search(r'[a-f]', opt_str.lower()[::-1]) return match.group(0).upper() if match else None ### Several Corner Cases ### # 1. \boxed{} boxed_pattern = r"\\boxed\{\s*(.*)\s*\}" all_matches = list(re.finditer(boxed_pattern, text, re.IGNORECASE)) if all_matches: return clean_option(all_matches[-1].group(1)) # 2. he answer is answer_pattern = r"he answer is\s*(.*)" all_matches = list(re.finditer(boxed_pattern, text, re.IGNORECASE)) if all_matches: return clean_option(all_matches[-1].group(1)) # 3. final answer is answer_pattern = r"final answer is\s*(.*)" all_matches = list(re.finditer(boxed_pattern, text, re.IGNORECASE)) if all_matches: return clean_option(all_matches[-1].group(1)) return None ###################### # MATH # ###################### def extract_MATH_solution(solution_str: str): """Extracts the final answer from the model's response string. Args: solution_str: Raw response string from the language model Returns: extracted final answer """"" # Split response to isolate assistant output if "Assistant:" in solution_str: processed_str = solution_str.split("Assistant:", 1)[1] elif "<|im_start|>assistant" in solution_str: processed_str = solution_str.split("<|im_start|>assistant", 1)[1] else: processed_str = solution_str # Extract final answer using XML-style tags answer_pattern = r'.*?(\\boxed{.*}).*?' matches = list(re.finditer(answer_pattern, processed_str, re.DOTALL)) if not matches: answer_pattern = r'\\boxed{(.*)}' matches = list(re.finditer(answer_pattern, processed_str, re.DOTALL)) if not matches: print("[Error] No valid answer tags found") return None final_answer = matches[-1].group(1).strip() return final_answer def _extract_answer(text): """ Extract numerical answer from generated text. handling various edge cases. Args: text (str): Generated text to extract answer from. Returns: str or None: Extracted numerical answer, or None if not found. """ if text is None: return None text = text.strip() def clean_number(num_str): """Remove currency symbols, commas, and whitespace.""" num_str = re.sub(r'[$€£¥]', '', num_str) num_str = re.sub(r',', '', num_str) num_str = re.sub(r'\s', '', num_str) return num_str ### Several Corner Cases ### # 1. \boxed{} boxed_pattern = r"\\boxed\{\s*([$€£¥]?\s*-?\s*[\d,]+(?:\.\d+)?)\s*\}" match = re.search(boxed_pattern, text, re.IGNORECASE) if match: return clean_number(match.group(1)) # 2. Answer: answer_pattern = r"Answer:\s*([$€£¥]?\s*-?\s*[\d,]+(?:\.\d+)?)" match = re.search(answer_pattern, text, re.IGNORECASE) if match: return clean_number(match.group(1)) # 3. = equals_pattern = r"=\s*([$€£¥]?\s*-?\s*[\d,]+(?:\.\d+)?)" match = re.search(equals_pattern, text) if match: return clean_number(match.group(1)) # 4. With currency unit currency_pattern = r"is\s*([$€£¥]?\s*-?\s*[\d,]+(?:\.\d+)?)\s*(?:dollars|euros|pounds|yen)" match = re.search(currency_pattern, text, re.IGNORECASE) if match: return clean_number(match.group(1)) # 5. Search from the last line of the text upwards, matching independent numbers lines = text.split('\n') for line in reversed(lines): line = line.strip() if line: final_num_pattern = r"([$€£¥]?\s*-?\s*[\d,]+(?:\.\d+)?)\s*$" match = re.search(final_num_pattern, line) if match: return clean_number(match.group(1)) # 6. Returns the last matching number in the text number_pattern = r"([$€£¥]?\s*-?\s*[\d,]+(?:\.\d+)?)" matches = re.findall(number_pattern, text) if matches: return clean_number(matches[-1]) return None def _extract_qwen25_1_5B_answer(text): """ Extract numerical answer from generated text for Qwen-2.5 1.5B model. handling various edge cases. Args: text (str): Generated text to extract answer from. Returns: str or None: Extracted numerical answer, or None if not found. """ if text is None: return None text = text.strip() def clean_number(num_str): """Remove currency symbols, commas, and whitespace.""" num_str = re.sub(r'[$€£¥]', '', num_str) num_str = re.sub(r',', '', num_str) num_str = re.sub(r'\s', '', num_str) return num_str ### Several Corner Cases ### # 1. \boxed{} boxed_pattern = r"\\boxed\{\s*([$€£¥]?\s*-?\s*[\d,]+(?:\.\d+)?)\s*\}" match = re.search(boxed_pattern, text, re.IGNORECASE) if match: return clean_number(match.group(1)) # 2. he answer is answer_pattern = r"he answer is\s*([$€£¥]?\s*-?\s*[\d,]+(?:\.\d+)?)" match = re.search(answer_pattern, text, re.IGNORECASE) if match: return clean_number(match.group(1)) # 3. final answer is answer_pattern = r"final answer is\s*([$€£¥]?\s*-?\s*[\d,]+(?:\.\d+)?)" match = re.search(answer_pattern, text, re.IGNORECASE) if match: return clean_number(match.group(1)) # 4. Returns the last matching number in the text number_pattern = r'\d+(?:,\d+)*(?:\.\d+)?' matches = re.findall(number_pattern, text) if matches: return clean_number(matches[-1]) return None