import re
import math
import logging
from difflib import SequenceMatcher
logger = logging.getLogger(__name__)


class MathReward:
    """Rule-based reward for completions of the form
    ``<think>...</think><answer>...</answer>``.

    ``compute_rewards`` adds up the following terms for each completion:

    * format: ``-2.0`` (and nothing else) when the think/answer tags are
      missing, ``+0.6`` when they are present;
    * reasoning quality ``q`` from ``check_reasoning_quality``: ``-0.5`` when
      ``q < 0.3``, otherwise ``+q``;
    * optional similarity ``s`` to a reference reasoning: ``+0.5 * s`` when
      ``s > 0.3``;
    * answer accuracy: ``+3.0`` when correct, ``-0.3`` / ``-0.8`` / ``-1.5``
      depending on the relative error, ``-1.0`` when no number can be parsed;
    * consistency: ``+0.2`` when the predicted value also appears in the
      reasoning text.
    """

    # Removed from the answer text before number extraction: spaces, thousands
    # separators, currency signs and common Chinese unit words.
    _STRIP_TOKENS = (" ", ",", "¥", "$", "千克", "元", "个", "只", "本", "米", "人")

    # The sign is grouped with BOTH alternatives; without the group it would
    # only apply to decimals and "-5" would be read as 5.
    _PLAIN_NUMBER = r"[-+]?(?:\d*\.\d+|\d+)"
    # Same, with an optional exponent. An "e" only counts as an exponent when
    # it directly follows a number, so ordinary words containing "e" are
    # harmless.
    _NUMBER = _PLAIN_NUMBER + r"(?:[eE][-+]?\d+)?"
    # One answer token: a number, optionally "/number" (fraction), optionally
    # followed by "%".
    _ANSWER_TOKEN_RE = re.compile(
        rf"(?P<num>{_NUMBER})(?:/(?P<den>{_NUMBER}))?(?P<pct>%)?"
    )
    _REASONING_NUMBER_RE = re.compile(_PLAIN_NUMBER)
    _MATH_EXPR_RE = re.compile(r"\d+\s*[+\-*/×÷=]\s*\d+")
    _STEP_MARKER_RE = re.compile(r"第\d+步|步骤\d+|^\d+[.、]", re.MULTILINE)

    def __init__(self, use_reference_comparison=True):
        """Create the reward function.

        Args:
            use_reference_comparison: Whether ``compute_rewards`` compares the
                generated reasoning with ``gt['reasoning']`` when that key is
                present.
        """
        self.format_pattern = re.compile(
            r"<think>(.*?)</think>\s*<answer>(.*?)</answer>", re.DOTALL
        )
        self.use_reference_comparison = use_reference_comparison
        # Substrings whose presence in the reasoning counts toward its quality
        # score (connectives, step words, arithmetic words and operators).
        self.reasoning_keywords = [
            '计算', '因为', '所以', '首先', '然后', '接着', '最后', '根据',
            '第一步', '第二步', '第三步', '第', '步', '得到', '等于',
            '加', '减', '乘', '除', '=', '+', '-', '*', '/', '÷', '×'
        ]

    def parse_number(self, text: str) -> "float | None":
        """Extract the final numeric value from an answer string.

        Supports integers, decimals, signed values, scientific notation
        ("1.5e-3"), fractions ("1/5"), percentages ("20%", returned as 0.2)
        and numbers with thousands separators ("1,000"). Spaces, commas,
        currency signs and a few Chinese unit words are removed first. When
        the text holds several numeric tokens, the last one is taken as the
        answer.

        Args:
            text: Raw answer text; may be empty or ``None``.

        Returns:
            The parsed value, or ``None`` when the text is empty, contains no
            number, or the last token is a fraction with a zero denominator.
        """
        if not text:
            return None

        clean_text = text.strip()
        for token in self._STRIP_TOKENS:
            clean_text = clean_text.replace(token, "")

        tokens = list(self._ANSWER_TOKEN_RE.finditer(clean_text))
        if not tokens:
            return None
        final = tokens[-1]  # the last token is treated as the final answer

        try:
            value = float(final.group("num"))
            if final.group("den") is not None:
                value /= float(final.group("den"))
            if final.group("pct"):
                value /= 100
        except (ValueError, ZeroDivisionError) as exc:
            logger.debug("解析数字失败: %s, 错误: %s", text, exc)
            return None
        return value

    def check_reasoning_quality(self, think_content: str) -> float:
        """Score the reasoning text with simple surface heuristics.

        The score is the sum of:

        * length: 0.3 for at least 100 characters, 0.15 for at least 50;
        * keywords: 0.05 per distinct entry of ``self.reasoning_keywords``
          found in the text, capped at 0.3;
        * arithmetic: 0.2 if at least one ``number operator number``
          expression is found, plus 0.1 if there are three or more;
        * structure: 0.1 if a step marker ("第1步", "步骤1", or a line starting
          with "1." / "1、") is found.

        Args:
            think_content: Text taken from inside the ``<think>`` tags.

        Returns:
            A score in ``[0.0, 1.0]``; ``0.0`` for empty input.
        """
        if not think_content:
            return 0.0

        quality_score = 0.0

        length = len(think_content)
        if length >= 100:
            quality_score += 0.3
        elif length >= 50:
            quality_score += 0.15

        keyword_count = sum(
            1 for keyword in self.reasoning_keywords if keyword in think_content
        )
        quality_score += min(keyword_count * 0.05, 0.3)

        expression_count = len(self._MATH_EXPR_RE.findall(think_content))
        if expression_count > 0:
            quality_score += 0.2
        if expression_count >= 3:
            quality_score += 0.1

        if self._STEP_MARKER_RE.search(think_content):
            quality_score += 0.1

        return min(quality_score, 1.0)

    def compute_reasoning_similarity(
        self, generated_reasoning: str, reference_reasoning: str
    ) -> float:
        """Return the order-aware similarity between two reasoning texts.

        Uses ``difflib.SequenceMatcher(...).ratio()`` on the raw strings.

        Args:
            generated_reasoning: Reasoning produced by the model.
            reference_reasoning: Reference reasoning to compare against.

        Returns:
            A ratio in ``[0.0, 1.0]``; ``0.0`` when either text is empty.
        """
        if not generated_reasoning or not reference_reasoning:
            return 0.0
        matcher = SequenceMatcher(None, generated_reasoning, reference_reasoning)
        return matcher.ratio()

    @staticmethod
    def _answer_reward(pred_val, gt_val):
        """Return the accuracy term for a parsed prediction (or ``None``)."""
        if pred_val is None:
            return -1.0
        if math.isclose(pred_val, gt_val, rel_tol=1e-4, abs_tol=1e-4):
            return 3.0
        # The epsilon keeps the division defined when the ground truth is 0.
        relative_error = abs(pred_val - gt_val) / (abs(gt_val) + 1e-8)
        if relative_error < 0.1:
            return -0.3
        if relative_error < 0.5:
            return -0.8
        return -1.5

    @classmethod
    def _answer_in_reasoning(cls, think_content, pred_val):
        """Tell whether ``pred_val`` appears among the numbers in the reasoning."""
        for token in cls._REASONING_NUMBER_RE.findall(think_content):
            # In running text a "-" before a number may be a negative sign or
            # a subtraction operator, so both readings are accepted.
            for reading in (token, token.lstrip("+-")):
                if math.isclose(float(reading), pred_val, rel_tol=1e-3, abs_tol=1e-3):
                    return True
        return False

    def compute_rewards(self, completions: list, ground_truths: list) -> list:
        """Compute one scalar reward per completion.

        Args:
            completions: Full generated texts (``list`` of ``str``).
            ground_truths: One ``dict`` per completion. ``'answer_val'``
                (numeric) is required. ``'reasoning'`` (``str``) is optional
                and is only used when ``use_reference_comparison`` is true.
                Other keys are ignored.

        Returns:
            A list of floats, one per (completion, ground truth) pair. Pairs
            are formed with ``zip``, so the shorter input decides the length.
            See the class docstring for how each reward is composed.
        """
        rewards = []
        for completion, gt in zip(completions, ground_truths):
            match = self.format_pattern.search(completion)
            if match is None:
                rewards.append(-2.0)
                continue

            think_content = match.group(1).strip()
            answer_content = match.group(2).strip()
            total_reward = 0.0
            total_reward += 0.6  # format bonus

            reasoning_quality = self.check_reasoning_quality(think_content)
            if reasoning_quality < 0.3:
                total_reward -= 0.5
            else:
                total_reward += reasoning_quality * 1.0

            if self.use_reference_comparison and 'reasoning' in gt:
                similarity = self.compute_reasoning_similarity(
                    think_content, gt['reasoning']
                )
                if similarity > 0.3:
                    total_reward += similarity * 0.5

            pred_val = self.parse_number(answer_content)
            # Normalise once so any float-convertible ground truth works in
            # the arithmetic below.
            gt_val = float(gt['answer_val'])
            total_reward += self._answer_reward(pred_val, gt_val)

            if pred_val is not None and self._answer_in_reasoning(
                think_content, pred_val
            ):
                total_reward += 0.2

            rewards.append(total_reward)
        return rewards

    def compute_metrics(self, completions: list, ground_truths: list) -> dict:
        """Aggregate format, reasoning-quality and accuracy statistics.

        Args:
            completions: Full generated texts (``list`` of ``str``).
            ground_truths: One ``dict`` per completion with key
                ``'answer_val'``; entries whose value is ``None`` are not
                counted as correct or close.

        Returns:
            A dict with the counts ``'format_correct'``, ``'answer_correct'``
            (within 1e-4), ``'answer_close'`` (not correct, but within 0.1
            relative or absolute tolerance) and ``'total'``
            (``len(completions)``); ``'reasoning_quality_avg'`` averaged over
            well-formatted completions; and ``'<count>_pct'`` percentages of
            ``'total'``, which are ``0.0`` when there are no completions.
        """
        total = len(completions)
        metrics = {
            'format_correct': 0,
            'reasoning_quality_avg': 0.0,
            'answer_correct': 0,
            'answer_close': 0,
            'total': total
        }
        quality_scores = []

        for completion, gt in zip(completions, ground_truths):
            match = self.format_pattern.search(completion)
            if not match:
                continue
            metrics['format_correct'] += 1
            think_content = match.group(1).strip()
            answer_content = match.group(2).strip()
            quality_scores.append(self.check_reasoning_quality(think_content))

            pred_val = self.parse_number(answer_content)
            gt_val = gt['answer_val']
            if pred_val is None or gt_val is None:
                continue
            if math.isclose(pred_val, gt_val, rel_tol=1e-4, abs_tol=1e-4):
                metrics['answer_correct'] += 1
            elif math.isclose(pred_val, gt_val, rel_tol=0.1, abs_tol=0.1):
                metrics['answer_close'] += 1

        if quality_scores:
            metrics['reasoning_quality_avg'] = sum(quality_scores) / len(quality_scores)

        # Guard the division: an empty batch reports 0% instead of raising.
        for key in ('format_correct', 'answer_correct', 'answer_close'):
            metrics[key + '_pct'] = metrics[key] / total * 100 if total else 0.0
        return metrics