szxllm commited on
Commit
68004aa
·
verified ·
1 Parent(s): 30028f1

Update math_verifier.py

Browse files
Files changed (1) hide show
  1. math_verifier.py +12 -52
math_verifier.py CHANGED
@@ -11,7 +11,6 @@ class MathReward:
11
  Args:
12
  use_reference_comparison: 是否使用参考答案进行推理过程比较
13
  """
14
- # 编译正则表达式,强制要求 <think> 在前,<answer> 在后
15
  self.format_pattern = re.compile(r"<think>(.*?)</think>\s*<answer>(.*?)</answer>", re.DOTALL)
16
  self.use_reference_comparison = use_reference_comparison
17
 
@@ -58,7 +57,7 @@ class MathReward:
58
  # 匹配 浮点数 或 整数,忽略可能混杂的文字
59
  matches = re.findall(r"[-+]?\d*\.\d+|\d+", clean_text)
60
  if matches:
61
- # 取最后一个作为最终答案(通常答案在最后)
62
  return float(matches[-1])
63
 
64
  except Exception as e:
@@ -77,28 +76,25 @@ class MathReward:
77
 
78
  quality_score = 0.0
79
 
80
- # 1. 长度检查(基础)
81
  length = len(think_content)
82
  if length >= 100:
83
  quality_score += 0.3
84
  elif length >= 50:
85
  quality_score += 0.15
86
 
87
- # 2. 关键词检查(推理步骤标识)
88
  keyword_count = sum(1 for kw in self.reasoning_keywords if kw in think_content)
89
- # 每出现一个关键词加分,最多加0.3分
90
  quality_score += min(keyword_count * 0.05, 0.3)
91
 
92
- # 3. 数学表达式检查(是否包含计算过程)
93
- # 匹配数学运算符或等式
94
  math_expressions = re.findall(r'\d+\s*[+\-*/×÷=]\s*\d+', think_content)
95
  if len(math_expressions) > 0:
96
  quality_score += 0.2
97
- # 多个表达式说明推理更详细
98
  if len(math_expressions) >= 3:
99
  quality_score += 0.1
100
 
101
- # 4. 结构检查(是否有步骤分隔)
102
  has_steps = bool(re.search(r'第\d+步|步骤\d+|^\d+[.、]', think_content, re.MULTILINE))
103
  if has_steps:
104
  quality_score += 0.1
@@ -114,8 +110,7 @@ class MathReward:
114
  """
115
  if not generated_reasoning or not reference_reasoning:
116
  return 0.0
117
-
118
- # 使用 difflib 的 SequenceMatcher 计算相似度
119
  similarity = SequenceMatcher(None, generated_reasoning, reference_reasoning).ratio()
120
 
121
  return similarity
@@ -137,77 +132,52 @@ class MathReward:
137
 
138
  for completion, gt in zip(completions, ground_truths):
139
  total_reward = 0.0
140
-
141
- # --- 1. 格式与结构检查 ---
142
  match = self.format_pattern.search(completion)
143
 
144
- # 如果没有匹配到 <think>...</think><answer>...</answer> 结构
145
  if match is None:
146
- # 格式严重错误,给予重罚
147
  rewards.append(-2.0)
148
  continue
149
-
150
- # 提取内容
151
  think_content = match.group(1).strip()
152
  answer_content = match.group(2).strip()
153
-
154
- # 格式正确的基础分
155
  total_reward += 0.6
156
-
157
- # --- 2. 思考过程质量检查 ---
158
  reasoning_quality = self.check_reasoning_quality(think_content)
159
 
160
  if reasoning_quality < 0.3:
161
- # 推理过程质量太低(可能是敷衍或格式化)
162
  total_reward -= 0.5
163
  else:
164
- # 推理质量越高,奖励越多
165
- total_reward += reasoning_quality * 1.0 # 最多1.0分
166
-
167
- # --- 3. 推理过程与参考对比(如果有参考) ---
168
  if self.use_reference_comparison and 'reasoning' in gt:
169
  reference_reasoning = gt['reasoning']
170
  similarity = self.compute_reasoning_similarity(think_content, reference_reasoning)
171
-
172
- # 相似度奖励(最多0.5分)
173
- # 注意:不要求完全一致,因为可能有多种正确推理方式
174
  if similarity > 0.3:
175
  total_reward += similarity * 0.5
176
 
177
- # --- 4. 答案准确性检查(最重要) ---
178
  pred_val = self.parse_number(answer_content)
179
  gt_val = gt['answer_val']
180
 
181
  if pred_val is not None:
182
- # 数值比较,允许 float 精度误差
183
  if math.isclose(pred_val, gt_val, rel_tol=1e-4, abs_tol=1e-4):
184
- # 答对给予最高奖励
185
  total_reward += 3.0
186
  else:
187
- # 答错扣分
188
- # 根据误差大小调整惩罚
189
  try:
190
  relative_error = abs(pred_val - gt_val) / (abs(gt_val) + 1e-8)
191
  if relative_error < 0.1:
192
- # 接近正确答案,轻微惩罚
193
  total_reward -= 0.3
194
  elif relative_error < 0.5:
195
- # 有一定误差
196
  total_reward -= 0.8
197
  else:
198
- # 完全错误
199
  total_reward -= 1.5
200
  except:
201
  total_reward -= 1.5
202
  else:
203
- # <answer> 标签内提取不到有效数字
204
  total_reward -= 1.0
205
 
206
- # --- 5. 一致性检查:推理过程中的数字应该与答案相关 ---
207
- # 提取推理过程中出现的所有数字
208
  reasoning_numbers = re.findall(r'[-+]?\d*\.\d+|\d+', think_content)
209
  if reasoning_numbers and pred_val is not None:
210
- # 检查答案是否出现在推理过程中
211
  answer_in_reasoning = any(
212
  math.isclose(float(num), pred_val, rel_tol=1e-3, abs_tol=1e-3)
213
  for num in reasoning_numbers
@@ -220,17 +190,11 @@ class MathReward:
220
  return rewards
221
 
222
  def compute_metrics(self, completions, ground_truths):
223
- """
224
- 计算详细的评估指标(用于分析)
225
-
226
- Returns:
227
- dict: 包含各种指标的字典
228
- """
229
  metrics = {
230
  'format_correct': 0,
231
  'reasoning_quality_avg': 0.0,
232
  'answer_correct': 0,
233
- 'answer_close': 0, # 答案接近但不完全正确
234
  'total': len(completions)
235
  }
236
 
@@ -245,11 +209,8 @@ class MathReward:
245
  think_content = match.group(1).strip()
246
  answer_content = match.group(2).strip()
247
 
248
- # 推理质量
249
  quality = self.check_reasoning_quality(think_content)
250
  quality_scores.append(quality)
251
-
252
- # 答案准确性
253
  pred_val = self.parse_number(answer_content)
254
  gt_val = gt['answer_val']
255
 
@@ -262,7 +223,6 @@ class MathReward:
262
  if quality_scores:
263
  metrics['reasoning_quality_avg'] = sum(quality_scores) / len(quality_scores)
264
 
265
- # 计算百分比
266
  metrics['format_correct_pct'] = metrics['format_correct'] / metrics['total'] * 100
267
  metrics['answer_correct_pct'] = metrics['answer_correct'] / metrics['total'] * 100
268
  metrics['answer_close_pct'] = metrics['answer_close'] / metrics['total'] * 100
 
11
  Args:
12
  use_reference_comparison: 是否使用参考答案进行推理过程比较
13
  """
 
14
  self.format_pattern = re.compile(r"<think>(.*?)</think>\s*<answer>(.*?)</answer>", re.DOTALL)
15
  self.use_reference_comparison = use_reference_comparison
16
 
 
57
  # 匹配 浮点数 或 整数,忽略可能混杂的文字
58
  matches = re.findall(r"[-+]?\d*\.\d+|\d+", clean_text)
59
  if matches:
60
+ # 取最后一个作为最终答案
61
  return float(matches[-1])
62
 
63
  except Exception as e:
 
76
 
77
  quality_score = 0.0
78
 
79
+ # 长度检查
80
  length = len(think_content)
81
  if length >= 100:
82
  quality_score += 0.3
83
  elif length >= 50:
84
  quality_score += 0.15
85
 
86
+ # 关键词检查
87
  keyword_count = sum(1 for kw in self.reasoning_keywords if kw in think_content)
 
88
  quality_score += min(keyword_count * 0.05, 0.3)
89
 
90
+ # 数学表达式检查
 
91
  math_expressions = re.findall(r'\d+\s*[+\-*/×÷=]\s*\d+', think_content)
92
  if len(math_expressions) > 0:
93
  quality_score += 0.2
 
94
  if len(math_expressions) >= 3:
95
  quality_score += 0.1
96
 
97
+ # 结构检查
98
  has_steps = bool(re.search(r'第\d+步|步骤\d+|^\d+[.、]', think_content, re.MULTILINE))
99
  if has_steps:
100
  quality_score += 0.1
 
110
  """
111
  if not generated_reasoning or not reference_reasoning:
112
  return 0.0
113
+
 
114
  similarity = SequenceMatcher(None, generated_reasoning, reference_reasoning).ratio()
115
 
116
  return similarity
 
132
 
133
  for completion, gt in zip(completions, ground_truths):
134
  total_reward = 0.0
 
 
135
  match = self.format_pattern.search(completion)
136
 
 
137
  if match is None:
 
138
  rewards.append(-2.0)
139
  continue
140
+
 
141
  think_content = match.group(1).strip()
142
  answer_content = match.group(2).strip()
 
 
143
  total_reward += 0.6
 
 
144
  reasoning_quality = self.check_reasoning_quality(think_content)
145
 
146
  if reasoning_quality < 0.3:
 
147
  total_reward -= 0.5
148
  else:
149
+ total_reward += reasoning_quality * 1.0
150
+
 
 
151
  if self.use_reference_comparison and 'reasoning' in gt:
152
  reference_reasoning = gt['reasoning']
153
  similarity = self.compute_reasoning_similarity(think_content, reference_reasoning)
154
+
 
 
155
  if similarity > 0.3:
156
  total_reward += similarity * 0.5
157
 
158
+ #答案准确性检查---
159
  pred_val = self.parse_number(answer_content)
160
  gt_val = gt['answer_val']
161
 
162
  if pred_val is not None:
 
163
  if math.isclose(pred_val, gt_val, rel_tol=1e-4, abs_tol=1e-4):
 
164
  total_reward += 3.0
165
  else:
 
 
166
  try:
167
  relative_error = abs(pred_val - gt_val) / (abs(gt_val) + 1e-8)
168
  if relative_error < 0.1:
 
169
  total_reward -= 0.3
170
  elif relative_error < 0.5:
 
171
  total_reward -= 0.8
172
  else:
 
173
  total_reward -= 1.5
174
  except:
175
  total_reward -= 1.5
176
  else:
 
177
  total_reward -= 1.0
178
 
 
 
179
  reasoning_numbers = re.findall(r'[-+]?\d*\.\d+|\d+', think_content)
180
  if reasoning_numbers and pred_val is not None:
 
181
  answer_in_reasoning = any(
182
  math.isclose(float(num), pred_val, rel_tol=1e-3, abs_tol=1e-3)
183
  for num in reasoning_numbers
 
190
  return rewards
191
 
192
  def compute_metrics(self, completions, ground_truths):
 
 
 
 
 
 
193
  metrics = {
194
  'format_correct': 0,
195
  'reasoning_quality_avg': 0.0,
196
  'answer_correct': 0,
197
+ 'answer_close': 0,
198
  'total': len(completions)
199
  }
200
 
 
209
  think_content = match.group(1).strip()
210
  answer_content = match.group(2).strip()
211
 
 
212
  quality = self.check_reasoning_quality(think_content)
213
  quality_scores.append(quality)
 
 
214
  pred_val = self.parse_number(answer_content)
215
  gt_val = gt['answer_val']
216
 
 
223
  if quality_scores:
224
  metrics['reasoning_quality_avg'] = sum(quality_scores) / len(quality_scores)
225
 
 
226
  metrics['format_correct_pct'] = metrics['format_correct'] / metrics['total'] * 100
227
  metrics['answer_correct_pct'] = metrics['answer_correct'] / metrics['total'] * 100
228
  metrics['answer_close_pct'] = metrics['answer_close'] / metrics['total'] * 100