# Upload-page metadata (not part of the source, preserved as a comment):
# yuccaaa — "Add files using upload-large-folder tool" — commit 9440cb3 (verified)
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import Any, Dict, List
# Constants for normalization
SUBSTITUTIONS = [
("an ", ""),
("a ", ""),
(".$", "$"),
("\\$", ""),
(r"\ ", ""),
(" ", ""),
("mbox", "text"),
(",\\text{and}", ","),
("\\text{and}", ","),
("\\text{m}", "\\text{}"),
]
REMOVED_EXPRESSIONS = [
"square",
"ways",
"integers",
"dollars",
"mph",
"inches",
"hours",
"km",
"units",
"\\ldots",
"sue",
"points",
"feet",
"minutes",
"digits",
"cents",
"degrees",
"cm",
"gm",
"pounds",
"meters",
"meals",
"edges",
"students",
"childrentickets",
"multiples",
"\\text{s}",
"\\text{.}",
"\\text{\ns}",
"\\text{}^2",
"\\text{}^3",
"\\text{\n}",
"\\text{}",
r"\mathrm{th}",
r"^\circ",
r"^{\circ}",
r"\;",
r",\!",
"{,}",
'"',
"\\dots",
]
def normalize_final_answer(final_answer: str) -> str:
    """Normalize a final answer to a quantitative reasoning question.

    Args:
        final_answer: The answer string to normalize

    Returns:
        Normalized answer string
    """
    # Keep only the right-hand side of the last "=".
    final_answer = final_answer.split("=")[-1]

    # Apply the textual substitutions, then strip the removed expressions,
    # in the same sequential order as two separate passes.
    for old, new in SUBSTITUTIONS + [(expr, "") for expr in REMOVED_EXPRESSIONS]:
        final_answer = final_answer.replace(old, new)

    # If a $...$ pair exists, keep only its (first, non-greedy) contents.
    final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", r"$\3$", final_answer)

    # Unwrap \text{...}, \textbf{...}, \overline{...} (non-greedy),
    # then \boxed{...} (greedy, matching the outermost closing brace).
    for cmd in ("text", "textbf", "overline"):
        final_answer = re.sub(rf"(\\{cmd}\{{)(.*?)(\}})", r"\2", final_answer)
    final_answer = re.sub(r"(\\boxed\{)(.*)(\})", r"\2", final_answer)

    # Normalize shorthand TeX:
    # \fracab -> \frac{a}{b}
    # \frac{abc}{bef} -> \frac{abc}{bef}
    # \fracabc -> \frac{a}{b}c
    # \sqrta -> \sqrt{a}
    # \sqrtab -> sqrt{a}b
    final_answer = re.sub(r"(frac)([^{])(.)", r"frac{\2}{\3}", final_answer)
    final_answer = re.sub(r"(sqrt)([^{])", r"sqrt{\2}", final_answer)
    final_answer = final_answer.replace("$", "")

    # Drop thousands separators from strings that are pure digits otherwise.
    if final_answer.replace(",", "").isdigit():
        final_answer = final_answer.replace(",", "")

    return final_answer.strip()
def accuracy_reward(response: str, ground_truth: str) -> float:
    """Return 1.0 when the last "Answer: ..." line in `response` matches
    `ground_truth` after normalization, and -1.0 otherwise."""
    matches = re.findall(r"(?i)Answer\s*:\s*([^\n]+)", response)
    # No extractable answer normalizes against a sentinel that cannot match.
    extracted = matches[-1] if matches else "[INVALID]"
    same = normalize_final_answer(extracted) == normalize_final_answer(ground_truth)
    return 1.0 if same else -1.0
def soft_overlong_punishment(response_length: int, max_response_length: int, overlong_buffer_length: int):
    """Piecewise length penalty: 0 within budget, linearly ramping to -1
    across the overlong buffer, capped at -1 past `max_response_length`."""
    expected_len = max_response_length - overlong_buffer_length
    if response_length > max_response_length:
        return -1.0
    if response_length > expected_len:
        # Linear ramp from 0 down to -1 over the buffer zone.
        return (expected_len - response_length) / overlong_buffer_length
    return 0.0
def compute_score(
    reward_inputs: List[Dict[str, Any]],
    max_response_length: int,
    overlong_buffer_length: int,
    overlong_penalty_factor: float,
) -> List[Dict[str, float]]:
    """Compute batched DAPO scores: per-sample accuracy plus a weighted
    soft overlong-length penalty.

    Args:
        reward_inputs: Batch of dicts with "response", "ground_truth",
            and "response_length" keys.
        max_response_length: Hard length cap used by the length penalty.
        overlong_buffer_length: Width of the soft-penalty buffer zone.
        overlong_penalty_factor: Weight applied to the length penalty.

    Returns:
        One score dict per input with "overall", "accuracy", "overlong",
        and "accuracy_normalized" entries.

    Raises:
        ValueError: If `reward_inputs` is not a list (batch reward expected).
    """
    if not isinstance(reward_inputs, list):
        raise ValueError("Please use `reward_type=batch` for dapo reward function.")

    results = []
    for sample in reward_inputs:
        # The longest answer in MATH-500 has 159 characters, so the last
        # 300 characters are enough to contain the final answer line.
        answer_tail = sample["response"][-300:]
        acc = accuracy_reward(answer_tail, sample["ground_truth"])
        length_score = soft_overlong_punishment(
            sample["response_length"], max_response_length, overlong_buffer_length
        )
        results.append(
            {
                "overall": acc + length_score * overlong_penalty_factor,
                "accuracy": acc,
                "overlong": length_score,
                "accuracy_normalized": 0.5 * (acc + 1.0),
            }
        )
    return results