import re
import math
import logging
from difflib import SequenceMatcher

logger = logging.getLogger(__name__)


class MathReward:
    """Rule-based reward for math completions with a reasoning part and an answer part.

    A completion is expected to look like
    ``<think>...reasoning...</think><answer>...final answer...</answer>``.
    The reward combines a format bonus, a heuristic reasoning-quality score,
    an optional similarity to a reference reasoning, the numeric accuracy of
    the answer, and a small bonus when the answer also appears in the reasoning.
    """

    # Unsigned decimal or integer literal: "3.14", ".5", "42".
    _UNSIGNED = r"(?:\d*\.\d+|\d+)"
    # Signed number with optional exponent.  The look-behind stops the "-" in
    # "3-5" from being read as the sign of the second operand.
    _NUMBER_RE = re.compile(r"(?<!\d)[-+]?" + _UNSIGNED + r"(?:[eE][-+]?\d+)?")
    # A number immediately followed by a percent sign; group 1 is the number.
    _PERCENT_RE = re.compile(r"(?<!\d)([-+]?" + _UNSIGNED + r")%")
    # A simple "a/b" fraction that is not part of a longer slash chain
    # (e.g. a date such as 2024/1/15); groups are numerator and denominator.
    _FRACTION_RE = re.compile(
        r"(?<![\d/.])([-+]?" + _UNSIGNED + r")/(" + _UNSIGNED + r")(?![/\d]|\.\d)"
    )
    # "<number> <operator> <number>" fragments inside the reasoning text.
    _MATH_EXPR_RE = re.compile(r'\d+\s*[+\-*/×÷=]\s*\d+')
    # Explicit step markers: "第1步", "步骤1", or a line starting with "1." / "1、".
    _STEP_RE = re.compile(r'第\d+步|步骤\d+|^\d+[.、]', re.MULTILINE)
    # Substrings removed before numeric parsing, in this order: whitespace,
    # thousands separator, currency symbols, then common Chinese units.
    _STRIP_TOKENS = (" ", ",", "¥", "$", "千克", "元", "个", "只", "本", "米", "人")

    def __init__(self, use_reference_comparison=True):
        """Create the reward function.

        Args:
            use_reference_comparison: Whether to compare the generated
                reasoning against the reference reasoning when the ground
                truth provides one.
        """
        # NOTE(review): the original literal was r"(.*?)\s*(.*?)", which matches
        # the empty string on every input, so the format check could never fail
        # and both captured groups were always empty.  The tag names below are
        # inferred from the think_content / answer_content variable names —
        # confirm they match the prompt template used for generation.
        self.format_pattern = re.compile(
            r"<think>(.*?)</think>\s*<answer>(.*?)</answer>", re.DOTALL
        )
        self.use_reference_comparison = use_reference_comparison
        # Keywords and operators whose presence counts towards reasoning quality.
        self.reasoning_keywords = [
            '计算', '因为', '所以', '首先', '然后', '接着', '最后', '根据',
            '第一步', '第二步', '第三步', '第', '步', '得到', '等于',
            '加', '减', '乘', '除', '=', '+', '-', '*', '/', '÷', '×'
        ]

    @staticmethod
    def _try_float(text):
        """Return ``float(text)``, or ``None`` when *text* is not a float literal."""
        try:
            return float(text)
        except ValueError:
            return None

    def parse_number(self, text):
        """Parse a numeric value out of *text*.

        Supports integers, decimals, negative numbers, fractions ("1/5"),
        percentages ("20%"), scientific notation ("1.5e-3") and numbers with
        thousands separators ("1,000").  Spaces, currency symbols and a few
        common Chinese units are stripped first.  Each special form is tried
        strictly on the whole cleaned text and then searched for inside it;
        when several candidates exist the last one is taken as the answer.

        Args:
            text: Raw answer text.

        Returns:
            The parsed value as a float, or ``None`` when *text* is empty,
            contains no number, or is a fraction with a zero denominator.
        """
        if not text:
            return None

        clean_text = text.strip()
        for token in self._STRIP_TOKENS:
            clean_text = clean_text.replace(token, "")

        # 1. Percentages, e.g. "20%" -> 0.2.
        if "%" in clean_text:
            value = self._try_float(clean_text.replace("%", ""))
            if value is not None:
                return value / 100
            found = self._PERCENT_RE.findall(clean_text)
            if found:
                return float(found[-1]) / 100

        # 2. Fractions, e.g. "1/5" or "x=42/5".
        if "/" in clean_text:
            parts = clean_text.split("/")
            if len(parts) == 2:
                numerator = self._try_float(parts[0])
                denominator = self._try_float(parts[1])
                if numerator is not None and denominator is not None:
                    # A zero denominator is not a number; do not fall through
                    # and report the denominator itself as the answer.
                    return None if denominator == 0 else numerator / denominator
            found = self._FRACTION_RE.findall(clean_text)
            if found:
                numerator, denominator = (float(part) for part in found[-1])
                return None if denominator == 0 else numerator / denominator

        # 3. Scientific notation, e.g. "1.5e-3".  A failed strict parse (the
        # "e" was just a letter in a word) falls through to extraction below.
        if "e" in clean_text.lower():
            value = self._try_float(clean_text)
            if value is not None:
                return value

        # 4. Generic extraction: take the last number in the text.
        found = self._NUMBER_RE.findall(clean_text)
        if found:
            return float(found[-1])

        logger.debug("解析数字失败: %s", text)
        return None

    def check_reasoning_quality(self, think_content):
        """Score the reasoning text with simple surface heuristics.

        The score adds up: length (0.15 from 50 characters, 0.3 from 100),
        keyword hits (0.05 each, capped at 0.3), arithmetic expressions
        (0.2 for at least one, another 0.1 for three or more) and explicit
        step markers (0.1).

        Args:
            think_content: The reasoning text.

        Returns:
            A quality score in the range 0.0 - 1.0.
        """
        if not think_content:
            return 0.0

        score = 0.0

        # Length.
        length = len(think_content)
        if length >= 100:
            score += 0.3
        elif length >= 50:
            score += 0.15

        # Keywords.
        keyword_hits = sum(1 for kw in self.reasoning_keywords if kw in think_content)
        score += min(keyword_hits * 0.05, 0.3)

        # Arithmetic expressions.
        expression_count = len(self._MATH_EXPR_RE.findall(think_content))
        if expression_count > 0:
            score += 0.2
        if expression_count >= 3:
            score += 0.1

        # Step structure.
        if self._STEP_RE.search(think_content):
            score += 0.1

        return min(score, 1.0)

    def compute_reasoning_similarity(self, generated_reasoning, reference_reasoning):
        """Return the order-aware similarity of two reasoning texts.

        Uses ``difflib.SequenceMatcher(...).ratio()`` on the raw strings.

        Args:
            generated_reasoning: Reasoning produced by the model.
            reference_reasoning: Reference reasoning from the ground truth.

        Returns:
            A similarity in the range 0.0 - 1.0; 0.0 when either text is empty.
        """
        if not generated_reasoning or not reference_reasoning:
            return 0.0
        return SequenceMatcher(None, generated_reasoning, reference_reasoning).ratio()

    @staticmethod
    def _score_answer(pred_val, gt_val):
        """Return the reward delta for the numeric accuracy of the answer."""
        if pred_val is None:
            # Nothing numeric could be parsed from the answer.
            return -1.0
        if gt_val is None:
            # No usable ground truth: accuracy cannot be judged, so this term
            # contributes nothing instead of raising inside math.isclose.
            logger.debug("ground truth has no answer_val; skipping accuracy term")
            return 0.0
        if math.isclose(pred_val, gt_val, rel_tol=1e-4, abs_tol=1e-4):
            return 3.0
        # Graded penalty; the epsilon keeps the denominator non-zero.
        relative_error = abs(pred_val - gt_val) / (abs(gt_val) + 1e-8)
        if relative_error < 0.1:
            return -0.3
        if relative_error < 0.5:
            return -0.8
        return -1.5

    def compute_rewards(self, completions, ground_truths):
        """Compute one reward per completion.

        Args:
            completions: List[str], full generated texts.
            ground_truths: List[dict], matching ground truths.  Each must
                contain 'answer_val' (float); 'reasoning' (str) is optional.

        Returns:
            List[float] with one reward per (completion, ground truth) pair.
            A completion that does not match the expected format gets -2.0.
        """
        rewards = []

        for completion, gt in zip(completions, ground_truths):
            match = self.format_pattern.search(completion)
            if match is None:
                rewards.append(-2.0)
                continue

            think_content = match.group(1).strip()
            answer_content = match.group(2).strip()

            # Format bonus.
            total_reward = 0.0
            total_reward += 0.6

            # Reasoning quality: penalise thin reasoning, otherwise reward it.
            reasoning_quality = self.check_reasoning_quality(think_content)
            if reasoning_quality < 0.3:
                total_reward -= 0.5
            else:
                total_reward += reasoning_quality * 1.0

            # Optional similarity to the reference reasoning.
            if self.use_reference_comparison and 'reasoning' in gt:
                similarity = self.compute_reasoning_similarity(
                    think_content, gt['reasoning']
                )
                if similarity > 0.3:
                    total_reward += similarity * 0.5

            # Answer accuracy.
            pred_val = self.parse_number(answer_content)
            total_reward += self._score_answer(pred_val, gt['answer_val'])

            # Consistency bonus: the final answer also appears in the reasoning.
            if pred_val is not None:
                answer_in_reasoning = any(
                    math.isclose(float(num), pred_val, rel_tol=1e-3, abs_tol=1e-3)
                    for num in self._NUMBER_RE.findall(think_content)
                )
                if answer_in_reasoning:
                    total_reward += 0.2

            rewards.append(total_reward)

        return rewards

    def compute_metrics(self, completions, ground_truths):
        """Compute aggregate evaluation metrics for a batch.

        Args:
            completions: List[str], full generated texts.
            ground_truths: List[dict], each containing 'answer_val'.

        Returns:
            A dict with the counts 'format_correct', 'answer_correct',
            'answer_close' and 'total', the mean 'reasoning_quality_avg' over
            well-formatted completions, and the '*_pct' percentages relative
            to 'total' (0.0 when the batch is empty).
        """
        metrics = {
            'format_correct': 0,
            'reasoning_quality_avg': 0.0,
            'answer_correct': 0,
            'answer_close': 0,
            'total': len(completions)
        }

        quality_scores = []

        for completion, gt in zip(completions, ground_truths):
            match = self.format_pattern.search(completion)
            if not match:
                continue

            metrics['format_correct'] += 1
            think_content = match.group(1).strip()
            answer_content = match.group(2).strip()

            quality_scores.append(self.check_reasoning_quality(think_content))

            pred_val = self.parse_number(answer_content)
            gt_val = gt['answer_val']
            if pred_val is None or gt_val is None:
                continue
            if math.isclose(pred_val, gt_val, rel_tol=1e-4, abs_tol=1e-4):
                metrics['answer_correct'] += 1
            elif math.isclose(pred_val, gt_val, rel_tol=0.1, abs_tol=0.1):
                metrics['answer_close'] += 1

        if quality_scores:
            metrics['reasoning_quality_avg'] = sum(quality_scores) / len(quality_scores)

        total = metrics['total']
        for key in ('format_correct', 'answer_correct', 'answer_close'):
            # An empty batch has no meaningful percentage; report 0.0 rather
            # than dividing by zero.
            metrics[key + '_pct'] = metrics[key] / total * 100 if total else 0.0

        return metrics