Update math_verifier.py
Browse files- math_verifier.py +12 -52
math_verifier.py
CHANGED
|
@@ -11,7 +11,6 @@ class MathReward:
|
|
| 11 |
Args:
|
| 12 |
use_reference_comparison: 是否使用参考答案进行推理过程比较
|
| 13 |
"""
|
| 14 |
-
# 编译正则表达式,强制要求 <think> 在前,<answer> 在后
|
| 15 |
self.format_pattern = re.compile(r"<think>(.*?)</think>\s*<answer>(.*?)</answer>", re.DOTALL)
|
| 16 |
self.use_reference_comparison = use_reference_comparison
|
| 17 |
|
|
@@ -58,7 +57,7 @@ class MathReward:
|
|
| 58 |
# 匹配 浮点数 或 整数,忽略可能混杂的文字
|
| 59 |
matches = re.findall(r"[-+]?\d*\.\d+|\d+", clean_text)
|
| 60 |
if matches:
|
| 61 |
-
# 取最后一个作为最终答案
|
| 62 |
return float(matches[-1])
|
| 63 |
|
| 64 |
except Exception as e:
|
|
@@ -77,28 +76,25 @@ class MathReward:
|
|
| 77 |
|
| 78 |
quality_score = 0.0
|
| 79 |
|
| 80 |
-
#
|
| 81 |
length = len(think_content)
|
| 82 |
if length >= 100:
|
| 83 |
quality_score += 0.3
|
| 84 |
elif length >= 50:
|
| 85 |
quality_score += 0.15
|
| 86 |
|
| 87 |
-
#
|
| 88 |
keyword_count = sum(1 for kw in self.reasoning_keywords if kw in think_content)
|
| 89 |
-
# 每出现一个关键词加分,最多加0.3分
|
| 90 |
quality_score += min(keyword_count * 0.05, 0.3)
|
| 91 |
|
| 92 |
-
#
|
| 93 |
-
# 匹配数学运算符或等式
|
| 94 |
math_expressions = re.findall(r'\d+\s*[+\-*/×÷=]\s*\d+', think_content)
|
| 95 |
if len(math_expressions) > 0:
|
| 96 |
quality_score += 0.2
|
| 97 |
-
# 多个表达式说明推理更详细
|
| 98 |
if len(math_expressions) >= 3:
|
| 99 |
quality_score += 0.1
|
| 100 |
|
| 101 |
-
#
|
| 102 |
has_steps = bool(re.search(r'第\d+步|步骤\d+|^\d+[.、]', think_content, re.MULTILINE))
|
| 103 |
if has_steps:
|
| 104 |
quality_score += 0.1
|
|
@@ -114,8 +110,7 @@ class MathReward:
|
|
| 114 |
"""
|
| 115 |
if not generated_reasoning or not reference_reasoning:
|
| 116 |
return 0.0
|
| 117 |
-
|
| 118 |
-
# 使用 difflib 的 SequenceMatcher 计算相似度
|
| 119 |
similarity = SequenceMatcher(None, generated_reasoning, reference_reasoning).ratio()
|
| 120 |
|
| 121 |
return similarity
|
|
@@ -137,77 +132,52 @@ class MathReward:
|
|
| 137 |
|
| 138 |
for completion, gt in zip(completions, ground_truths):
|
| 139 |
total_reward = 0.0
|
| 140 |
-
|
| 141 |
-
# --- 1. 格式与结构检查 ---
|
| 142 |
match = self.format_pattern.search(completion)
|
| 143 |
|
| 144 |
-
# 如果没有匹配到 <think>...</think><answer>...</answer> 结构
|
| 145 |
if match is None:
|
| 146 |
-
# 格式严重错误,给予重罚
|
| 147 |
rewards.append(-2.0)
|
| 148 |
continue
|
| 149 |
-
|
| 150 |
-
# 提取内容
|
| 151 |
think_content = match.group(1).strip()
|
| 152 |
answer_content = match.group(2).strip()
|
| 153 |
-
|
| 154 |
-
# 格式正确的基础分
|
| 155 |
total_reward += 0.6
|
| 156 |
-
|
| 157 |
-
# --- 2. 思考过程质量检查 ---
|
| 158 |
reasoning_quality = self.check_reasoning_quality(think_content)
|
| 159 |
|
| 160 |
if reasoning_quality < 0.3:
|
| 161 |
-
# 推理过程质量太低(可能是敷衍或格式化)
|
| 162 |
total_reward -= 0.5
|
| 163 |
else:
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
# --- 3. 推理过程与参考对比(如果有参考) ---
|
| 168 |
if self.use_reference_comparison and 'reasoning' in gt:
|
| 169 |
reference_reasoning = gt['reasoning']
|
| 170 |
similarity = self.compute_reasoning_similarity(think_content, reference_reasoning)
|
| 171 |
-
|
| 172 |
-
# 相似度奖励(最多0.5分)
|
| 173 |
-
# 注意:不要求完全一致,因为可能有多种正确推理方式
|
| 174 |
if similarity > 0.3:
|
| 175 |
total_reward += similarity * 0.5
|
| 176 |
|
| 177 |
-
#
|
| 178 |
pred_val = self.parse_number(answer_content)
|
| 179 |
gt_val = gt['answer_val']
|
| 180 |
|
| 181 |
if pred_val is not None:
|
| 182 |
-
# 数值比较,允许 float 精度误差
|
| 183 |
if math.isclose(pred_val, gt_val, rel_tol=1e-4, abs_tol=1e-4):
|
| 184 |
-
# 答对给予最高奖励
|
| 185 |
total_reward += 3.0
|
| 186 |
else:
|
| 187 |
-
# 答错扣分
|
| 188 |
-
# 根据误差大小调整惩罚
|
| 189 |
try:
|
| 190 |
relative_error = abs(pred_val - gt_val) / (abs(gt_val) + 1e-8)
|
| 191 |
if relative_error < 0.1:
|
| 192 |
-
# 接近正确答案,轻微惩罚
|
| 193 |
total_reward -= 0.3
|
| 194 |
elif relative_error < 0.5:
|
| 195 |
-
# 有一定误差
|
| 196 |
total_reward -= 0.8
|
| 197 |
else:
|
| 198 |
-
# 完全错误
|
| 199 |
total_reward -= 1.5
|
| 200 |
except:
|
| 201 |
total_reward -= 1.5
|
| 202 |
else:
|
| 203 |
-
# <answer> 标签内提取不到有效数字
|
| 204 |
total_reward -= 1.0
|
| 205 |
|
| 206 |
-
# --- 5. 一致性检查:推理过程中的数字应该与答案相关 ---
|
| 207 |
-
# 提取推理过程中出现的所有数字
|
| 208 |
reasoning_numbers = re.findall(r'[-+]?\d*\.\d+|\d+', think_content)
|
| 209 |
if reasoning_numbers and pred_val is not None:
|
| 210 |
-
# 检查答案是否出现在推理过程中
|
| 211 |
answer_in_reasoning = any(
|
| 212 |
math.isclose(float(num), pred_val, rel_tol=1e-3, abs_tol=1e-3)
|
| 213 |
for num in reasoning_numbers
|
|
@@ -220,17 +190,11 @@ class MathReward:
|
|
| 220 |
return rewards
|
| 221 |
|
| 222 |
def compute_metrics(self, completions, ground_truths):
|
| 223 |
-
"""
|
| 224 |
-
计算详细的评估指标(用于分析)
|
| 225 |
-
|
| 226 |
-
Returns:
|
| 227 |
-
dict: 包含各种指标的字典
|
| 228 |
-
"""
|
| 229 |
metrics = {
|
| 230 |
'format_correct': 0,
|
| 231 |
'reasoning_quality_avg': 0.0,
|
| 232 |
'answer_correct': 0,
|
| 233 |
-
'answer_close': 0,
|
| 234 |
'total': len(completions)
|
| 235 |
}
|
| 236 |
|
|
@@ -245,11 +209,8 @@ class MathReward:
|
|
| 245 |
think_content = match.group(1).strip()
|
| 246 |
answer_content = match.group(2).strip()
|
| 247 |
|
| 248 |
-
# 推理质量
|
| 249 |
quality = self.check_reasoning_quality(think_content)
|
| 250 |
quality_scores.append(quality)
|
| 251 |
-
|
| 252 |
-
# 答案准确性
|
| 253 |
pred_val = self.parse_number(answer_content)
|
| 254 |
gt_val = gt['answer_val']
|
| 255 |
|
|
@@ -262,7 +223,6 @@ class MathReward:
|
|
| 262 |
if quality_scores:
|
| 263 |
metrics['reasoning_quality_avg'] = sum(quality_scores) / len(quality_scores)
|
| 264 |
|
| 265 |
-
# 计算百分比
|
| 266 |
metrics['format_correct_pct'] = metrics['format_correct'] / metrics['total'] * 100
|
| 267 |
metrics['answer_correct_pct'] = metrics['answer_correct'] / metrics['total'] * 100
|
| 268 |
metrics['answer_close_pct'] = metrics['answer_close'] / metrics['total'] * 100
|
|
|
|
| 11 |
Args:
|
| 12 |
use_reference_comparison: 是否使用参考答案进行推理过程比较
|
| 13 |
"""
|
|
|
|
| 14 |
self.format_pattern = re.compile(r"<think>(.*?)</think>\s*<answer>(.*?)</answer>", re.DOTALL)
|
| 15 |
self.use_reference_comparison = use_reference_comparison
|
| 16 |
|
|
|
|
| 57 |
# 匹配 浮点数 或 整数,忽略可能混杂的文字
|
| 58 |
matches = re.findall(r"[-+]?\d*\.\d+|\d+", clean_text)
|
| 59 |
if matches:
|
| 60 |
+
# 取最后一个作为最终答案
|
| 61 |
return float(matches[-1])
|
| 62 |
|
| 63 |
except Exception as e:
|
|
|
|
| 76 |
|
| 77 |
quality_score = 0.0
|
| 78 |
|
| 79 |
+
# 长度检查
|
| 80 |
length = len(think_content)
|
| 81 |
if length >= 100:
|
| 82 |
quality_score += 0.3
|
| 83 |
elif length >= 50:
|
| 84 |
quality_score += 0.15
|
| 85 |
|
| 86 |
+
# 关键词检查
|
| 87 |
keyword_count = sum(1 for kw in self.reasoning_keywords if kw in think_content)
|
|
|
|
| 88 |
quality_score += min(keyword_count * 0.05, 0.3)
|
| 89 |
|
| 90 |
+
# 数学表达式检查
|
|
|
|
| 91 |
math_expressions = re.findall(r'\d+\s*[+\-*/×÷=]\s*\d+', think_content)
|
| 92 |
if len(math_expressions) > 0:
|
| 93 |
quality_score += 0.2
|
|
|
|
| 94 |
if len(math_expressions) >= 3:
|
| 95 |
quality_score += 0.1
|
| 96 |
|
| 97 |
+
# 结构检查
|
| 98 |
has_steps = bool(re.search(r'第\d+步|步骤\d+|^\d+[.、]', think_content, re.MULTILINE))
|
| 99 |
if has_steps:
|
| 100 |
quality_score += 0.1
|
|
|
|
| 110 |
"""
|
| 111 |
if not generated_reasoning or not reference_reasoning:
|
| 112 |
return 0.0
|
| 113 |
+
|
|
|
|
| 114 |
similarity = SequenceMatcher(None, generated_reasoning, reference_reasoning).ratio()
|
| 115 |
|
| 116 |
return similarity
|
|
|
|
| 132 |
|
| 133 |
for completion, gt in zip(completions, ground_truths):
|
| 134 |
total_reward = 0.0
|
|
|
|
|
|
|
| 135 |
match = self.format_pattern.search(completion)
|
| 136 |
|
|
|
|
| 137 |
if match is None:
|
|
|
|
| 138 |
rewards.append(-2.0)
|
| 139 |
continue
|
| 140 |
+
|
|
|
|
| 141 |
think_content = match.group(1).strip()
|
| 142 |
answer_content = match.group(2).strip()
|
|
|
|
|
|
|
| 143 |
total_reward += 0.6
|
|
|
|
|
|
|
| 144 |
reasoning_quality = self.check_reasoning_quality(think_content)
|
| 145 |
|
| 146 |
if reasoning_quality < 0.3:
|
|
|
|
| 147 |
total_reward -= 0.5
|
| 148 |
else:
|
| 149 |
+
total_reward += reasoning_quality * 1.0
|
| 150 |
+
|
|
|
|
|
|
|
| 151 |
if self.use_reference_comparison and 'reasoning' in gt:
|
| 152 |
reference_reasoning = gt['reasoning']
|
| 153 |
similarity = self.compute_reasoning_similarity(think_content, reference_reasoning)
|
| 154 |
+
|
|
|
|
|
|
|
| 155 |
if similarity > 0.3:
|
| 156 |
total_reward += similarity * 0.5
|
| 157 |
|
| 158 |
+
#答案准确性检查---
|
| 159 |
pred_val = self.parse_number(answer_content)
|
| 160 |
gt_val = gt['answer_val']
|
| 161 |
|
| 162 |
if pred_val is not None:
|
|
|
|
| 163 |
if math.isclose(pred_val, gt_val, rel_tol=1e-4, abs_tol=1e-4):
|
|
|
|
| 164 |
total_reward += 3.0
|
| 165 |
else:
|
|
|
|
|
|
|
| 166 |
try:
|
| 167 |
relative_error = abs(pred_val - gt_val) / (abs(gt_val) + 1e-8)
|
| 168 |
if relative_error < 0.1:
|
|
|
|
| 169 |
total_reward -= 0.3
|
| 170 |
elif relative_error < 0.5:
|
|
|
|
| 171 |
total_reward -= 0.8
|
| 172 |
else:
|
|
|
|
| 173 |
total_reward -= 1.5
|
| 174 |
except:
|
| 175 |
total_reward -= 1.5
|
| 176 |
else:
|
|
|
|
| 177 |
total_reward -= 1.0
|
| 178 |
|
|
|
|
|
|
|
| 179 |
reasoning_numbers = re.findall(r'[-+]?\d*\.\d+|\d+', think_content)
|
| 180 |
if reasoning_numbers and pred_val is not None:
|
|
|
|
| 181 |
answer_in_reasoning = any(
|
| 182 |
math.isclose(float(num), pred_val, rel_tol=1e-3, abs_tol=1e-3)
|
| 183 |
for num in reasoning_numbers
|
|
|
|
| 190 |
return rewards
|
| 191 |
|
| 192 |
def compute_metrics(self, completions, ground_truths):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
metrics = {
|
| 194 |
'format_correct': 0,
|
| 195 |
'reasoning_quality_avg': 0.0,
|
| 196 |
'answer_correct': 0,
|
| 197 |
+
'answer_close': 0,
|
| 198 |
'total': len(completions)
|
| 199 |
}
|
| 200 |
|
|
|
|
| 209 |
think_content = match.group(1).strip()
|
| 210 |
answer_content = match.group(2).strip()
|
| 211 |
|
|
|
|
| 212 |
quality = self.check_reasoning_quality(think_content)
|
| 213 |
quality_scores.append(quality)
|
|
|
|
|
|
|
| 214 |
pred_val = self.parse_number(answer_content)
|
| 215 |
gt_val = gt['answer_val']
|
| 216 |
|
|
|
|
| 223 |
if quality_scores:
|
| 224 |
metrics['reasoning_quality_avg'] = sum(quality_scores) / len(quality_scores)
|
| 225 |
|
|
|
|
| 226 |
metrics['format_correct_pct'] = metrics['format_correct'] / metrics['total'] * 100
|
| 227 |
metrics['answer_correct_pct'] = metrics['answer_correct'] / metrics['total'] * 100
|
| 228 |
metrics['answer_close_pct'] = metrics['answer_close'] / metrics['total'] * 100
|