| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import re |
| | from typing import Any, Dict, List |
| |
|
| |
|
| | |
# Ordered (old, new) literal replacements used by normalize_final_answer.
# They are applied with str.replace in list order, so each pair sees the
# output of the pairs before it.
SUBSTITUTIONS = [
    ("an ", ""),
    ("a ", ""),
    (".$", "$"),  # drop a period that directly precedes a "$"
    (r"\$", ""),  # drop escaped dollar signs
    (r"\ ", ""),  # drop LaTeX escaped spaces
    (" ", ""),  # drop all remaining plain spaces
    ("mbox", "text"),
    (r",\text{and}", ","),
    (r"\text{and}", ","),
    (r"\text{m}", r"\text{}"),
]
| |
|
# Literal substrings that normalize_final_answer deletes, in this exact order,
# after SUBSTITUTIONS has been applied. Order matters: e.g. "\text{}^2" must be
# removed before the bare "\text{}" entry would consume its prefix.
REMOVED_EXPRESSIONS = [
    # assorted words (mostly units / nouns that trail a numeric answer)
    "square", "ways", "integers", "dollars", "mph", "inches", "hours", "km",
    "units",
    r"\ldots",
    "sue", "points", "feet", "minutes", "digits", "cents", "degrees", "cm",
    "gm", "pounds", "meters", "meals", "edges", "students", "childrentickets",
    "multiples",
    # \text{...} remnants; the two non-raw literals embed a real newline
    r"\text{s}",
    r"\text{.}",
    "\\text{\ns}",
    r"\text{}^2",
    r"\text{}^3",
    "\\text{\n}",
    r"\text{}",
    # other LaTeX decorations, spacing commands and separators
    r"\mathrm{th}",
    r"^\circ",
    r"^{\circ}",
    r"\;",
    r",\!",
    "{,}",
    '"',
    r"\dots",
]
| |
|
| |
|
def normalize_final_answer(final_answer: str) -> str:
    r"""Canonicalize a final answer so two answers can be compared as strings.

    Steps, in order:
      1. keep only the text after the last ``=``;
      2. apply every ``(old, new)`` pair of ``SUBSTITUTIONS``, then delete
         every substring listed in ``REMOVED_EXPRESSIONS``;
      3. run a fixed sequence of regex rewrites: keep only the first
         ``$...$`` span, unwrap ``\text{}``, ``\textbf{}``, ``\overline{}``
         and ``\boxed{}``, and brace shorthand ``frac``/``sqrt`` arguments;
      4. remove every ``$``; if what remains is all digits once commas are
         ignored, drop the commas.

    Args:
        final_answer: The raw answer text.

    Returns:
        The normalized answer with surrounding whitespace stripped.
    """
    # Everything up to and including the last "=" is discarded.
    text = final_answer.rpartition("=")[2]

    for old, new in SUBSTITUTIONS:
        text = text.replace(old, new)
    for token in REMOVED_EXPRESSIONS:
        text = text.replace(token, "")

    # Ordered regex rewrites; each one operates on the previous one's output.
    rewrites = (
        (r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$"),  # keep the first $...$ span
        (r"(\\text\{)(.*?)(\})", "\\2"),  # \text{x} -> x
        (r"(\\textbf\{)(.*?)(\})", "\\2"),  # \textbf{x} -> x
        (r"(\\overline\{)(.*?)(\})", "\\2"),  # \overline{x} -> x
        (r"(\\boxed\{)(.*)(\})", "\\2"),  # \boxed{x} -> x (greedy to last "}")
        (r"(frac)([^{])(.)", "frac{\\2}{\\3}"),  # fracab -> frac{a}{b}
        (r"(sqrt)([^{])", "sqrt{\\2}"),  # sqrta -> sqrt{a}
    )
    for pattern, replacement in rewrites:
        text = re.sub(pattern, replacement, text)

    text = text.replace("$", "")

    # Purely numeric answers lose their comma separators.
    without_commas = text.replace(",", "")
    if without_commas.isdigit():
        text = without_commas

    return text.strip()
| |
|
| |
|
def accuracy_reward(response: str, ground_truth: str) -> float:
    """Grade a response by its last "Answer: ..." line.

    The search for "Answer" is case-insensitive and tolerates whitespace
    around the colon; the captured text runs to the end of that line. If no
    such line exists, the sentinel "[INVALID]" is compared instead. Both
    sides go through ``normalize_final_answer`` before comparison.

    Args:
        response: Model output to grade.
        ground_truth: Reference answer.

    Returns:
        1.0 if the normalized strings are equal, otherwise -1.0.
    """
    candidates = re.findall(r"(?i)Answer\s*:\s*([^\n]+)", response)
    predicted = "[INVALID]"
    if candidates:
        predicted = candidates[-1]

    is_correct = normalize_final_answer(predicted) == normalize_final_answer(ground_truth)
    return 1.0 if is_correct else -1.0
| |
|
| |
|
def soft_overlong_punishment(
    response_length: int, max_response_length: int, overlong_buffer_length: int
) -> float:
    """Linear length penalty over the last ``overlong_buffer_length`` units.

    Let ``expected_len = max_response_length - overlong_buffer_length``.

    Args:
        response_length: Length of the response being scored.
        max_response_length: Hard length limit.
        overlong_buffer_length: Size of the soft-penalty zone just below the
            hard limit.

    Returns:
        0.0 when ``response_length <= expected_len``; a value falling linearly
        from 0.0 to -1.0 while the length is within
        ``(expected_len, max_response_length]``; -1.0 beyond the hard limit.
        The division is only reached when ``expected_len < response_length
        <= max_response_length``, so a zero buffer never divides by zero.
    """
    expected_len = max_response_length - overlong_buffer_length
    if response_length <= expected_len:
        return 0.0
    elif response_length <= max_response_length:
        return (expected_len - response_length) / overlong_buffer_length
    else:
        return -1.0
| |
|
| |
|
def compute_score(
    reward_inputs: List[Dict[str, Any]],
    max_response_length: int,
    overlong_buffer_length: int,
    overlong_penalty_factor: float,
    *,
    answer_window: int = 300,
) -> List[Dict[str, float]]:
    """Score a batch of responses for the dapo reward function.

    Each element of ``reward_inputs`` is read for the keys ``"response"``,
    ``"response_length"`` and ``"ground_truth"``.

    Args:
        reward_inputs: Batch of per-sample inputs (must be a list).
        max_response_length: Hard length limit passed to the overlong penalty.
        overlong_buffer_length: Soft-penalty zone size passed to the overlong
            penalty.
        overlong_penalty_factor: Weight applied to the overlong score in
            ``"overall"``.
        answer_window: Only the last ``answer_window`` characters of each
            response are searched for the "Answer:" line (default 300, the
            previously hard-coded value). A value ``<= 0`` searches the whole
            response.

    Returns:
        One dict per input with keys ``"overall"`` (accuracy plus weighted
        overlong score), ``"accuracy"`` (1.0 or -1.0), ``"overlong"`` and
        ``"accuracy_normalized"`` (accuracy mapped from {-1, 1} to {0, 1}).

    Raises:
        ValueError: If ``reward_inputs`` is not a list.
    """
    if not isinstance(reward_inputs, list):
        raise ValueError("Please use `reward_type=batch` for dapo reward function.")

    scores = []
    for reward_input in reward_inputs:
        response = reward_input["response"]
        if answer_window > 0:
            # Explicit guard: response[-0:] would silently return the whole
            # string, so non-positive windows are treated as "no truncation".
            response = response[-answer_window:]
        accuracy_score = accuracy_reward(response, reward_input["ground_truth"])
        overlong_score = soft_overlong_punishment(
            reward_input["response_length"], max_response_length, overlong_buffer_length
        )
        scores.append(
            {
                "overall": accuracy_score + overlong_score * overlong_penalty_factor,
                "accuracy": accuracy_score,
                "overlong": overlong_score,
                "accuracy_normalized": 0.5 * (accuracy_score + 1.0),
            }
        )

    return scores
| |
|