|
|
"""
|
|
|
Collection of various reward signals for the arithmetic problem.
|
|
|
"""
|
|
|
|
|
|
import logging
|
|
|
import re
|
|
|
|
|
|
from src.utils.string_helper import (
|
|
|
extract_answers_from_completions,
|
|
|
extract_response_from_completions,
|
|
|
)
|
|
|
|
|
|
# Module-wide logging setup: configure the root logger at INFO level (per the
# stdlib docs, basicConfig only adds a handler if none is configured yet) and
# expose a named logger used by every reward function below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name="rewards")
|
|
|
|
|
|
|
|
|
def _is_valid_arithmetic_expression(expression: str) -> bool:
|
|
|
"""
|
|
|
Check if a string is a valid arithmetic expression containing only:
|
|
|
- Numbers (integers only)
|
|
|
- Arithmetic operators: +, -, x, /
|
|
|
- Whitespace
|
|
|
|
|
|
Args:
|
|
|
expression: The expression to validate
|
|
|
|
|
|
Returns:
|
|
|
bool: True if valid arithmetic expression, False otherwise
|
|
|
"""
|
|
|
if not expression or not expression.strip():
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
pattern = r"^[\d\s\+\-x\/]+$"
|
|
|
|
|
|
|
|
|
if not re.match(pattern, expression):
|
|
|
return False
|
|
|
|
|
|
|
|
|
has_number = re.search(r"\d", expression)
|
|
|
has_operator = re.search(r"[\+\-x\/]", expression)
|
|
|
|
|
|
if not (has_number and has_operator):
|
|
|
return False
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
normalized = expression.replace("x", "*")
|
|
|
|
|
|
|
|
|
normalized = "".join(normalized.split())
|
|
|
|
|
|
|
|
|
if re.search(r"[\+\-\*\/]{2,}", normalized):
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
eval(normalized)
|
|
|
return True
|
|
|
|
|
|
except (SyntaxError, ValueError):
|
|
|
|
|
|
return False
|
|
|
except ZeroDivisionError:
|
|
|
|
|
|
return True
|
|
|
except:
|
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
def _calculate_distance_based_reward(
    answer: str,
    correct_answer: int,
    *,
    max_reward: float = 2.0,
    penalty_per_unit: float = 0.2,
) -> float:
    """
    Score an arithmetic expression by how close its value is to the correct answer.

    Uses linear scaling:
        reward = max(0, max_reward - (distance * penalty_per_unit))
    where distance = |value(answer) - correct_answer|. A distance below 1e-4
    counts as an exact match and returns max_reward.

    Args:
        answer: The arithmetic expression to evaluate (operators +, -, x, /).
        correct_answer: The expected result.
        max_reward: Reward for an exact match (default 2.0).
        penalty_per_unit: Reward removed per unit of distance (default 0.2).

    Returns:
        float: Reward between 0.0 and max_reward. Returns 0.0 for empty or
        invalid expressions and for expressions that cannot be evaluated
        (e.g. division by zero or overflow).
    """
    if not answer or not answer.strip():
        return 0.0

    if not _is_valid_arithmetic_expression(answer):
        return 0.0

    try:
        # The validator restricts the text to digits, whitespace and + - x /,
        # so eval only sees plain arithmetic; builtins are removed anyway.
        result = eval(answer.replace("x", "*"), {"__builtins__": {}}, {})
        if not isinstance(result, (int, float)):
            return 0.0

        distance = abs(result - correct_answer)
        if distance < 0.0001:
            return max_reward

        return max(0.0, max_reward - (distance * penalty_per_unit))
    except Exception:
        # SyntaxError, ZeroDivisionError, OverflowError, or a correct_answer
        # of an incompatible type all score 0.0. Unlike a bare except, this
        # lets KeyboardInterrupt/SystemExit propagate.
        return 0.0
|
|
|
|
|
|
|
|
|
def format_reward_functiondef(
    completions: list[list[dict[str, str]]], **kwargs: dict[str, any]
) -> list[float]:
    """
    Reward completions that contain a <think>...</think> block followed by an
    <answer>...</answer> block.

    Args:
        completions: List of completions of the format:
            [
                [
                    {"role": "user", "content": "..."},
                    {"role": "assistant", "content": "..."},
                ]
            ]

    Returns:
        List of rewards: 1.0 when the think-then-answer structure is found
        anywhere in the response (across newlines), 0.0 otherwise.
    """
    think_then_answer = re.compile(
        r"<think>.*?</think>.*?<answer>.*?</answer>", re.DOTALL
    )
    rewards = []
    for response in extract_response_from_completions(completions):
        found = think_then_answer.search(response) is not None
        rewards.append(1.0 if found else 0.0)
    return rewards
|
|
|
|
|
|
|
|
|
def arithmetic_format_reward_function(
    completions: list[list[dict[str, str]]],
    **kwargs: dict[str, any],
) -> list[float]:
    """
    Reward completions whose <answer> content is a valid arithmetic expression.

    A valid answer uses only numbers, the operators +, -, x, / and spaces,
    for example:
        - "1 + 2 x 6 / 3"
        - "2 x 1 + 3 - 1"
        - "4 + 5 x 2 - 1"

    Args:
        completions: List of completions of the format:
            [
                [
                    {"role": "user", "content": "..."},
                    {"role": "assistant", "content": "..."},
                ]
            ]

    Returns:
        List of rewards (1.0 for valid arithmetic expressions, 0.0 otherwise).
    """
    rewards: list[float] = []
    for candidate in extract_answers_from_completions(completions):
        if _is_valid_arithmetic_expression(candidate):
            rewards.append(1.0)
        else:
            rewards.append(0.0)
    return rewards
|
|
|
|
|
|
|
|
|
def correctness_reward_function(
    completions: list[list[dict[str, str]]], **kwargs: dict[str, any]
) -> list[float]:
    """
    Reward each answer by how close its arithmetic value is to the correct result.

    The reward is calculated using linear scaling:
    - Perfect match (distance = 0): reward = 2.0
    - Each unit of distance reduces reward by 0.2 points
    - Minimum reward is 0.0
    - Invalid expressions get 0.0

    Args:
        completions: List of completions of the format:
            [
                [
                    {"role": "user", "content": "..."},
                    {"role": "assistant", "content": "..."},
                ]
            ]
        **kwargs: Must contain 'correct_answer'. It is either a single expected
            result applied to every completion, or a list/tuple with one
            expected result per completion.

    Returns:
        List of rewards (0.0 to 2.0 based on distance from correct answer).

    Raises:
        ValueError: If 'correct_answer' is missing from kwargs, or is a
            list/tuple whose length differs from the number of answers.
    """
    if "correct_answer" not in kwargs:
        raise ValueError(
            "correctness_reward_function requires a 'correct_answer' keyword argument"
        )
    correct_answer = kwargs["correct_answer"]

    answers = extract_answers_from_completions(completions)
    responses = [completion[-1]["content"] for completion in completions]

    # A sequence is treated as one target per completion (the same convention
    # mathematical_correctness_reward_function relies on). Passing the whole
    # list to every answer would make each subtraction fail and score 0.0.
    if isinstance(correct_answer, (list, tuple)):
        if len(correct_answer) != len(answers):
            raise ValueError(
                f"Got {len(correct_answer)} correct answers for {len(answers)} completions"
            )
        targets = list(correct_answer)
    else:
        targets = [correct_answer] * len(answers)

    if responses:
        logger.info("First completion: %s", responses[0])
    if answers:
        logger.info("First answer: %s", answers[0])

    return [
        _calculate_distance_based_reward(answer, target)
        for answer, target in zip(answers, targets)
    ]
|
|
|
|
|
|
|
|
|
# Layout of the diagnostic boxes logged by mathematical_correctness_reward_function:
# "│ " + 67 characters of content + " │" = 71 columns.
_BOX_INNER_WIDTH = 67
_BOX_RULE = "─" * (_BOX_INNER_WIDTH + 2)

# Characters allowed in a submitted equation: digits, + - * / . and whitespace.
_ALLOWED_EQUATION_PATTERN = re.compile(r"^[\d+\-*/.\s]+$")


def _box_row(text: object) -> str:
    """Flatten whitespace in *text* and pad or truncate it to the box width."""
    flat = " ".join(str(text).split())
    if len(flat) > _BOX_INNER_WIDTH:
        flat = flat[: _BOX_INNER_WIDTH - 3] + "..."
    return flat.ljust(_BOX_INNER_WIDTH)


def _log_box(title: str, *rows: object) -> None:
    """Log *title* and optional detail *rows* at INFO level inside a text box."""
    if not logger.isEnabledFor(logging.INFO):
        return  # skip building the strings when INFO output is disabled
    logger.info("┌%s┐", _BOX_RULE)
    logger.info("│ %s │", _box_row(title))
    if rows:
        logger.info("├%s┤", _BOX_RULE)
        for row in rows:
            logger.info("│ %s │", _box_row(row))
    logger.info("└%s┘", _BOX_RULE)


def mathematical_correctness_reward_function(
    completions: list[list[dict[str, str]]], **kwargs
) -> list[float]:
    """
    Score completions on answer format, number usage, allowed characters and
    mathematical correctness.

    Points accumulate, and scoring stops at the first failed check:
    - +1.0 if the completion contains an <answer>...</answer> tag
    - +1.0 if the equation uses exactly the numbers num1..num4 (as a multiset)
    - +1.0 if the equation contains only digits, + - * / . and whitespace,
      and no "**" power operator
    - +4.0 if the equation evaluates to correct_answer (within 1e-5)

    If the equation contains "=", only the text left of the first "=" is
    scored. A completion that raises during processing (e.g. division by zero)
    keeps the points earned so far. The maximum reward is 7.0.

    Args:
        completions: List of conversations; the "content" of the last message
            of each conversation is scored.
        **kwargs: Must contain "correct_answer", "num1", "num2", "num3" and
            "num4", each a sequence aligned with completions. Iteration stops
            at the shortest of these sequences.

    Returns:
        list[float]: Reward scores between 0.0 and 7.0.
    """
    contents = [completion[-1]["content"] for completion in completions]
    targets = kwargs["correct_answer"]
    first_nums = kwargs["num1"]
    second_nums = kwargs["num2"]
    third_nums = kwargs["num3"]
    fourth_nums = kwargs["num4"]
    rewards: list[float] = []

    if contents:
        logger.info("Completion:\n%s", contents[0])

    for content, gt, *correct_numbers in zip(
        contents,
        targets,
        first_nums,
        second_nums,
        third_nums,
        fourth_nums,
        strict=False,
    ):
        reward = 0.0
        # Reset per completion so an error report never shows the equation
        # left over from a previous iteration.
        equation = None
        try:
            match = re.search(r"<answer>(.*?)<\/answer>", content, re.DOTALL)
            if match is None:
                _log_box(
                    "❌ FORMAT ERROR: No <answer> tags found in completion",
                    f"Completion snippet: {content}",
                )
                rewards.append(reward)
                continue
            reward += 1.0  # format

            equation = match.group(1).strip()
            if "=" in equation:
                # Score only the expression side of "expr = value".
                equation = equation.split("=")[0]

            used_numbers = [int(n) for n in re.findall(r"\d+", equation)]
            if sorted(used_numbers) != sorted(correct_numbers):
                _log_box(
                    "❌ NUMBER USAGE ERROR: Incorrect numbers used",
                    f"Equation: {equation}",
                    f"Expected numbers: {correct_numbers}",
                    f"Used numbers: {used_numbers}",
                )
                rewards.append(reward)
                continue
            reward += 1.0  # number usage

            # The character class alone would admit "**"; exponentiation of
            # the given numbers (e.g. 9**9**9) can stall eval for a very long
            # time, so it is rejected explicitly.
            if not _ALLOWED_EQUATION_PATTERN.match(equation) or "**" in equation:
                _log_box(
                    "❌ INVALID CHARACTERS: Equation contains disallowed characters",
                    f"Equation: {equation}",
                )
                rewards.append(reward)
                continue
            reward += 1.0  # allowed characters

            # NOTE(review): eval on model output; acceptable only because the
            # checks above leave digits, + - * / . and whitespace.
            result = eval(equation, {"__builtins__": None}, {})
            difference = abs(float(result) - float(gt))

            if difference < 1e-5:
                _log_box(
                    "✅ CORRECT ANSWER: Perfect match!",
                    f"Equation: {equation} = {result}",
                    f"Target: {gt}",
                )
                reward += 4.0  # correct result
            else:
                _log_box(
                    "❌ WRONG RESULT: Equation evaluated to incorrect value",
                    f"Equation: {equation} = {result}",
                    f"Expected: {gt}",
                    f"Difference: {difference}",
                )
            rewards.append(reward)
        except Exception as e:
            _log_box(
                "❌ EVALUATION ERROR: Exception occurred during processing",
                f"Error: {e}",
                f"Equation: {equation if equation is not None else 'N/A'}",
            )
            rewards.append(reward)
    return rewards
|
|
|
|