| import re |
| import math |
| import logging |
| from difflib import SequenceMatcher |
|
|
| logger = logging.getLogger(__name__) |
|
|
class MathReward:
    """Rule-based reward and metrics for ``<think>...</think><answer>...</answer>`` completions.

    A completion is scored on four components: the presence of the
    think/answer format, a heuristic quality score of the reasoning text,
    optional textual similarity to a reference reasoning, and the numeric
    correctness of the answer.
    """

    # Substrings removed (in this order) before number extraction: spaces,
    # thousands separators, currency symbols and common Chinese measure words.
    _STRIP_TOKENS = (" ", ",", "¥", "$", "千克", "元", "个", "只", "本", "米", "人")

    # One numeric literal: optional sign, integer or decimal, optional exponent.
    # The sign is only accepted when it does not directly follow a digit, so the
    # minus in "10-5" is read as an operator, while "-5" and "=-5" stay negative.
    _NUMBER = r"(?:(?<!\d)[-+])?(?:\d*\.\d+|\d+)(?:[eE][-+]?\d+)?"
    _NUMBER_RE = re.compile(_NUMBER)

    # "<digits> <operator> <digits>" fragments inside the reasoning text.
    _MATH_EXPR_RE = re.compile(r'\d+\s*[+\-*/×÷=]\s*\d+')
    # Explicit step markers: "第N步", "步骤N", or a line starting with "N." / "N、".
    _STEP_RE = re.compile(r'第\d+步|步骤\d+|^\d+[.、]', re.MULTILINE)

    def __init__(self, use_reference_comparison=True):
        """Create a reward calculator.

        Args:
            use_reference_comparison: When true, ``compute_rewards`` adds a
                bonus based on the similarity between the generated reasoning
                and ``gt['reasoning']`` (if that key is present).
        """
        self.format_pattern = re.compile(r"<think>(.*?)</think>\s*<answer>(.*?)</answer>", re.DOTALL)
        self.use_reference_comparison = use_reference_comparison

        # Words and symbols whose presence counts as evidence of step-by-step reasoning.
        self.reasoning_keywords = [
            '计算', '因为', '所以', '首先', '然后', '接着', '最后', '根据',
            '第一步', '第二步', '第三步', '第', '步', '得到', '等于',
            '加', '减', '乘', '除', '=', '+', '-', '*', '/', '÷', '×'
        ]

    def parse_number(self, text):
        """Extract a numeric value from free-form answer text.

        Supports integers, decimals, scientific notation, fractions (``1/5``),
        percentages (``20%`` -> ``0.2``) and comma-grouped numbers
        (``1,000``), optionally surrounded by other text or units. When
        several numbers are present the last one is used.

        Args:
            text: Answer text; may be empty or ``None``.

        Returns:
            The parsed value as ``float``, or ``None`` when no number is found
            or the text is a fraction with a zero denominator.
        """
        if not text:
            return None

        clean_text = text.strip()
        for token in self._STRIP_TOKENS:
            clean_text = clean_text.replace(token, "")

        numbers = list(self._NUMBER_RE.finditer(clean_text))
        if not numbers:
            return None

        last = numbers[-1]
        value = float(last.group())

        # Percentage: the final number is immediately followed by a percent sign.
        # A tuple is used so an empty slice (end of string) never matches.
        if clean_text[last.end():last.end() + 1] in ("%", "％"):
            return value / 100

        # Fraction: only when the text holds exactly one "/" and it sits directly
        # between the last two numbers; dates such as 2024/01/15 are not fractions.
        if len(numbers) >= 2 and clean_text.count("/") == 1:
            numerator = numbers[-2]
            if clean_text[numerator.end():last.start()] == "/":
                if value == 0:
                    # Division by zero has no value; do not fall back to 0.0.
                    return None
                return float(numerator.group()) / value

        return value

    def check_reasoning_quality(self, think_content):
        """Heuristically score the quality of a reasoning text.

        Points are given for length, reasoning keywords, arithmetic
        expressions and explicit step markers.

        Args:
            think_content: Text found inside the ``<think>`` tag.

        Returns:
            A score in the range 0.0 - 1.0 (0.0 for empty input).
        """
        if not think_content:
            return 0.0

        quality_score = 0.0

        # Length: longer reasoning earns more, up to 0.3.
        length = len(think_content)
        if length >= 100:
            quality_score += 0.3
        elif length >= 50:
            quality_score += 0.15

        # Keywords: 0.05 per distinct keyword present, capped at 0.3.
        keyword_count = sum(1 for kw in self.reasoning_keywords if kw in think_content)
        quality_score += min(keyword_count * 0.05, 0.3)

        # Arithmetic expressions: 0.2 for at least one, another 0.1 for three or more.
        expression_count = len(self._MATH_EXPR_RE.findall(think_content))
        if expression_count > 0:
            quality_score += 0.2
        if expression_count >= 3:
            quality_score += 0.1

        # Explicit step structure.
        if self._STEP_RE.search(think_content):
            quality_score += 0.1

        return min(quality_score, 1.0)

    def compute_reasoning_similarity(self, generated_reasoning, reference_reasoning):
        """Return the similarity between generated and reference reasoning.

        Uses ``difflib.SequenceMatcher`` on the raw strings, so character
        order matters.

        Returns:
            A ratio in the range 0.0 - 1.0; 0.0 when either text is empty.
        """
        if not generated_reasoning or not reference_reasoning:
            return 0.0

        return SequenceMatcher(None, generated_reasoning, reference_reasoning).ratio()

    @staticmethod
    def _answer_reward(pred_val, gt_val):
        """Return the reward delta for a predicted answer against the truth."""
        if pred_val is None:
            # No number could be extracted from the answer.
            return -1.0
        if gt_val is None:
            # Nothing to compare against: neither reward nor penalise.
            return 0.0
        if math.isclose(pred_val, gt_val, rel_tol=1e-4, abs_tol=1e-4):
            return 3.0

        # Wrong answer: the penalty grows with the relative error. The small
        # constant keeps the division defined when the truth is zero.
        relative_error = abs(pred_val - gt_val) / (abs(gt_val) + 1e-8)
        if relative_error < 0.1:
            return -0.3
        if relative_error < 0.5:
            return -0.8
        return -1.5

    def compute_rewards(self, completions, ground_truths):
        """Compute one scalar reward per completion.

        Args:
            completions: List[str] of full generated texts.
            ground_truths: List[dict], one per completion.
                Required key: 'answer_val' (float; ``None`` skips answer scoring).
                Optional key: 'reasoning' (str) reference reasoning.

        Returns:
            List[float] of rewards. A completion without the
            ``<think>...</think><answer>...</answer>`` format gets -2.0.
        """
        rewards = []

        for completion, gt in zip(completions, ground_truths):
            match = self.format_pattern.search(completion)
            if match is None:
                rewards.append(-2.0)
                continue

            think_content = match.group(1).strip()
            answer_content = match.group(2).strip()

            # Base reward for following the format.
            total_reward = 0.6

            # Reasoning quality: penalise thin reasoning, otherwise add the score.
            reasoning_quality = self.check_reasoning_quality(think_content)
            if reasoning_quality < 0.3:
                total_reward -= 0.5
            else:
                total_reward += reasoning_quality

            # Optional bonus for resembling the reference reasoning.
            if self.use_reference_comparison and 'reasoning' in gt:
                similarity = self.compute_reasoning_similarity(think_content, gt['reasoning'])
                if similarity > 0.3:
                    total_reward += similarity * 0.5

            # Answer correctness.
            pred_val = self.parse_number(answer_content)
            total_reward += self._answer_reward(pred_val, gt['answer_val'])

            # Consistency bonus: the answer value also appears in the reasoning.
            if pred_val is not None:
                answer_in_reasoning = any(
                    math.isclose(float(num), pred_val, rel_tol=1e-3, abs_tol=1e-3)
                    for num in self._NUMBER_RE.findall(think_content)
                )
                if answer_in_reasoning:
                    total_reward += 0.2

            rewards.append(total_reward)

        return rewards

    def compute_metrics(self, completions, ground_truths):
        """Aggregate format, reasoning and answer statistics over a batch.

        Args:
            completions: List[str] of full generated texts.
            ground_truths: List[dict] with key 'answer_val' for each completion.

        Returns:
            dict with the counts 'format_correct', 'answer_correct',
            'answer_close' and 'total', the mean 'reasoning_quality_avg' over
            format-correct completions, and a '<count>_pct' percentage for
            each count (0.0 when the batch is empty).
        """
        metrics = {
            'format_correct': 0,
            'reasoning_quality_avg': 0.0,
            'answer_correct': 0,
            'answer_close': 0,
            'total': len(completions)
        }

        quality_scores = []

        for completion, gt in zip(completions, ground_truths):
            match = self.format_pattern.search(completion)
            if not match:
                continue

            metrics['format_correct'] += 1

            think_content = match.group(1).strip()
            answer_content = match.group(2).strip()

            quality_scores.append(self.check_reasoning_quality(think_content))

            pred_val = self.parse_number(answer_content)
            gt_val = gt['answer_val']
            if pred_val is None or gt_val is None:
                continue

            if math.isclose(pred_val, gt_val, rel_tol=1e-4, abs_tol=1e-4):
                metrics['answer_correct'] += 1
            elif math.isclose(pred_val, gt_val, rel_tol=0.1, abs_tol=0.1):
                metrics['answer_close'] += 1

        if quality_scores:
            metrics['reasoning_quality_avg'] = sum(quality_scores) / len(quality_scores)

        # Guard the percentages so an empty batch does not divide by zero.
        total = metrics['total']
        for key in ('format_correct', 'answer_correct', 'answer_close'):
            metrics[key + '_pct'] = metrics[key] / total * 100 if total else 0.0

        return metrics