import re
import math
import logging
from difflib import SequenceMatcher
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class MathReward:
    """Rule-based reward and metrics for math completions.

    A completion is expected to contain a reasoning section followed by an
    answer section (see ``DEFAULT_FORMAT_PATTERN``).  ``compute_rewards``
    scores each completion on format, reasoning quality, optional similarity
    to a reference reasoning, and numeric answer accuracy;
    ``compute_metrics`` aggregates counts over a batch.
    """

    # NOTE(review): the pattern previously compiled here was
    # r"(.*?)\s*(.*?)", which has no literal delimiters and therefore
    # matches the empty string on every input (both groups always empty).
    # The <think>/<answer> tags below are restored on the assumption that
    # they were stripped from the file -- confirm against the prompt
    # template, or pass ``format_pattern`` explicitly.
    DEFAULT_FORMAT_PATTERN = r"<think>(.*?)</think>\s*<answer>(.*?)</answer>"

    # logging.getLogger caches by name, so this is the same logger object a
    # module-level ``logging.getLogger(__name__)`` would return.
    _logger = logging.getLogger(__name__)

    # Whitespace, thousands separators, currency symbols and common Chinese
    # unit words that are removed before looking for a number.
    _NOISE_RE = re.compile(r"[\s,¥$]|千克|元|个|只|本|米|人")

    # One numeric literal: optional sign, integer or decimal, optional
    # exponent.  The lookbehind keeps a '-'/'+' that directly follows a digit
    # from being read as a sign, so "10-3" yields 10 and 3 (not 10 and -3),
    # and stops a match from starting in the middle of another number.
    _NUMBER = r"(?<![\d.])[-+]?(?:\d*\.\d+|\d+)(?:[eE][-+]?\d+)?"
    _NUMBER_RE = re.compile(_NUMBER)

    # A numeric token: a number optionally followed by "/<number>" (fraction)
    # or "%" (percentage).
    _TOKEN_RE = re.compile(
        rf"(?P<num>{_NUMBER})(?:/(?P<den>{_NUMBER})|(?P<pct>%))?"
    )

    # "<digits> <operator> <digits>" fragments inside the reasoning text.
    _MATH_EXPR_RE = re.compile(r'\d+\s*[+\-*/×÷=]\s*\d+')

    # Step markers: "第N步", "步骤N", or a line starting with "N." / "N、".
    _STEP_RE = re.compile(r'第\d+步|步骤\d+|^\d+[.、]', re.MULTILINE)

    def __init__(self, use_reference_comparison=True, *, format_pattern=None):
        """Configure the reward.

        Args:
            use_reference_comparison: Whether ``compute_rewards`` compares the
                generated reasoning with a reference reasoning (when the
                ground truth provides one).
            format_pattern: Optional regex (string or compiled pattern) with
                exactly two groups: group 1 is the reasoning text, group 2 the
                answer text.  A string is compiled with ``re.DOTALL``.
                Defaults to ``DEFAULT_FORMAT_PATTERN``.
        """
        if format_pattern is None:
            format_pattern = self.DEFAULT_FORMAT_PATTERN
        if isinstance(format_pattern, str):
            format_pattern = re.compile(format_pattern, re.DOTALL)
        self.format_pattern = format_pattern
        self.use_reference_comparison = use_reference_comparison
        # Keywords whose presence in the reasoning raises the quality score.
        self.reasoning_keywords = [
            '计算', '因为', '所以', '首先', '然后', '接着', '最后', '根据',
            '第一步', '第二步', '第三步', '第', '步', '得到', '等于',
            '加', '减', '乘', '除', '=', '+', '-', '*', '/', '÷', '×'
        ]

    def parse_number(self, text):
        """Parse the final numeric value out of ``text``.

        Supports integers, decimals, signed values, scientific notation,
        fractions ("1/5"), percentages ("20%" -> 0.2) and thousands
        separators ("1,000").  Whitespace, currency symbols and a few common
        unit words are removed first.  When several numeric tokens are
        present, the last one is taken as the answer.

        Args:
            text: Text expected to contain the answer.

        Returns:
            The parsed value as ``float``, or ``None`` when ``text`` is empty,
            contains no number, or ends in a fraction with a zero denominator.
        """
        if not text:
            return None

        clean_text = self._NOISE_RE.sub("", text)
        tokens = list(self._TOKEN_RE.finditer(clean_text))
        if not tokens:
            self._logger.debug(
                "解析数字失败: %s, 错误: %s", text, "no numeric token"
            )
            return None

        # The last token is treated as the final answer.
        last = tokens[-1]
        value = float(last.group("num"))

        if last.group("den") is not None:
            denominator = float(last.group("den"))
            if denominator == 0:
                self._logger.debug(
                    "解析数字失败: %s, 错误: %s", text, "zero denominator"
                )
                return None
            return value / denominator

        if last.group("pct"):
            return value / 100
        return value

    def check_reasoning_quality(self, think_content):
        """Score the reasoning text with simple heuristics.

        The score is the sum of four components, capped at 1.0:
        length (0.3 for >= 100 chars, 0.15 for >= 50), distinct keywords
        present (0.05 each, at most 0.3), arithmetic expressions (0.2 for at
        least one, 0.1 more for three or more) and step markers (0.1).

        Args:
            think_content: The reasoning text.

        Returns:
            A quality score in the range 0.0 - 1.0 (0.0 for empty input).
        """
        if not think_content:
            return 0.0

        quality_score = 0.0

        # Length component.
        length = len(think_content)
        if length >= 100:
            quality_score += 0.3
        elif length >= 50:
            quality_score += 0.15

        # Keyword component: each keyword counts once, however often it occurs.
        keyword_count = sum(
            1 for kw in self.reasoning_keywords if kw in think_content
        )
        quality_score += min(keyword_count * 0.05, 0.3)

        # Arithmetic-expression component.
        expression_count = len(self._MATH_EXPR_RE.findall(think_content))
        if expression_count > 0:
            quality_score += 0.2
        if expression_count >= 3:
            quality_score += 0.1

        # Structure component.
        if self._STEP_RE.search(think_content):
            quality_score += 0.1

        return min(quality_score, 1.0)

    def compute_reasoning_similarity(self, generated_reasoning, reference_reasoning):
        """Return the ``SequenceMatcher`` ratio of the two reasoning texts.

        Args:
            generated_reasoning: Reasoning text produced by the model.
            reference_reasoning: Reference reasoning text.

        Returns:
            A similarity in the range 0.0 - 1.0; 0.0 when either text is empty.
        """
        if not generated_reasoning or not reference_reasoning:
            return 0.0
        return SequenceMatcher(
            None, generated_reasoning, reference_reasoning
        ).ratio()

    def _split_completion(self, completion):
        """Return ``(reasoning, answer)`` stripped, or ``None`` if the format is absent."""
        match = self.format_pattern.search(completion)
        if match is None:
            return None
        return match.group(1).strip(), match.group(2).strip()

    @staticmethod
    def _accuracy_reward(pred_val, gt_val):
        """Return the signed reward for the predicted value against the truth."""
        if pred_val is None:
            # The answer section contained no parseable number.
            return -1.0
        if math.isclose(pred_val, gt_val, rel_tol=1e-4, abs_tol=1e-4):
            return 3.0
        # The small epsilon avoids dividing by zero when the truth is 0.
        relative_error = abs(pred_val - gt_val) / (abs(gt_val) + 1e-8)
        if relative_error < 0.1:
            return -0.3
        if relative_error < 0.5:
            return -0.8
        return -1.5

    def compute_rewards(self, completions, ground_truths):
        """Compute one scalar reward per completion.

        Per completion:
          * format not found: the reward is -2.0 and nothing else is scored;
          * format found: +0.6;
          * reasoning quality below 0.3: -0.5, otherwise + quality;
          * reference similarity above 0.3 (only when enabled and the ground
            truth has a 'reasoning' key): + similarity * 0.5;
          * answer: +3.0 when it matches the truth (tolerance 1e-4), else
            -0.3 / -0.8 / -1.5 for relative error below 0.1 / below 0.5 /
            otherwise, or -1.0 when no number could be parsed;
          * +0.2 when the parsed answer also appears among the numbers in the
            reasoning text.

        Args:
            completions: List[str], full texts generated by the model.
            ground_truths: List[dict], one per completion.  Must contain
                'answer_val' (float); may contain 'reasoning' (str).

        Returns:
            List[float] of rewards, one per (completion, ground truth) pair.
        """
        rewards = []
        for completion, gt in zip(completions, ground_truths):
            parts = self._split_completion(completion)
            if parts is None:
                rewards.append(-2.0)
                continue
            think_content, answer_content = parts

            total_reward = 0.6

            reasoning_quality = self.check_reasoning_quality(think_content)
            if reasoning_quality < 0.3:
                total_reward -= 0.5
            else:
                total_reward += reasoning_quality * 1.0

            if self.use_reference_comparison and 'reasoning' in gt:
                similarity = self.compute_reasoning_similarity(
                    think_content, gt['reasoning']
                )
                if similarity > 0.3:
                    total_reward += similarity * 0.5

            # Answer accuracy.
            pred_val = self.parse_number(answer_content)
            total_reward += self._accuracy_reward(pred_val, gt['answer_val'])

            # Consistency bonus: the answer also shows up in the reasoning.
            if pred_val is not None and any(
                math.isclose(float(num), pred_val, rel_tol=1e-3, abs_tol=1e-3)
                for num in self._NUMBER_RE.findall(think_content)
            ):
                total_reward += 0.2

            rewards.append(total_reward)
        return rewards

    def compute_metrics(self, completions, ground_truths):
        """Aggregate format / reasoning / accuracy statistics over a batch.

        Args:
            completions: List[str], full texts generated by the model.
            ground_truths: List[dict], one per completion, each with an
                'answer_val' entry (a ``None`` value is skipped).

        Returns:
            dict with the counts 'format_correct', 'answer_correct',
            'answer_close' and 'total', the mean 'reasoning_quality_avg' over
            well-formatted completions, and the '*_pct' percentages relative
            to 'total' (0.0 when the batch is empty).
        """
        total = len(completions)
        metrics = {
            'format_correct': 0,
            'reasoning_quality_avg': 0.0,
            'answer_correct': 0,
            'answer_close': 0,
            'total': total,
        }
        quality_scores = []

        for completion, gt in zip(completions, ground_truths):
            parts = self._split_completion(completion)
            if parts is None:
                continue
            think_content, answer_content = parts
            metrics['format_correct'] += 1
            quality_scores.append(self.check_reasoning_quality(think_content))

            pred_val = self.parse_number(answer_content)
            gt_val = gt['answer_val']
            if pred_val is None or gt_val is None:
                continue
            if math.isclose(pred_val, gt_val, rel_tol=1e-4, abs_tol=1e-4):
                metrics['answer_correct'] += 1
            elif math.isclose(pred_val, gt_val, rel_tol=0.1, abs_tol=0.1):
                metrics['answer_close'] += 1

        if quality_scores:
            metrics['reasoning_quality_avg'] = (
                sum(quality_scores) / len(quality_scores)
            )

        # Guard against an empty batch instead of raising ZeroDivisionError.
        for key in ('format_correct', 'answer_correct', 'answer_close'):
            metrics[key + '_pct'] = (
                metrics[key] / total * 100 if total else 0.0
            )
        return metrics