File size: 8,702 Bytes
263d741
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68004aa
263d741
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68004aa
263d741
 
 
 
 
 
68004aa
263d741
 
 
68004aa
263d741
 
 
 
 
 
68004aa
263d741
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68004aa
263d741
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68004aa
263d741
 
 
 
 
 
 
 
68004aa
 
263d741
 
 
68004aa
263d741
 
 
68004aa
263d741
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68004aa
263d741
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
import re
import math
import logging
from difflib import SequenceMatcher

# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)

class MathReward:
    """Rule-based reward for math completions of the form
    ``<think>...</think> <answer>...</answer>``.

    The reward combines a format bonus, a heuristic reasoning-quality score,
    an optional similarity to a reference reasoning trace, answer accuracy
    against a numeric ground truth, and a small bonus when the final answer
    also appears inside the reasoning.
    """

    # One number: optional sign, digits with an optional fractional part (or a
    # bare ".5"), optional exponent. The sign is only taken when it is NOT
    # preceded by a digit, "." or ")", so in "10-5" the "-" is read as
    # subtraction (10, 5) rather than as a negative number, while "-5" and
    # "x=-5" keep their sign.
    _NUMBER = r"(?:(?<![\d.)])[-+])?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?"
    # A number, optionally followed by "/<number>" (fraction) and/or "%".
    # Groups: 1 = numerator/value, 2 = denominator, 3 = percent sign.
    _ANSWER_TOKEN_RE = re.compile(rf"({_NUMBER})(?:/({_NUMBER}))?(%)?")
    _NUMBER_RE = re.compile(_NUMBER)
    # Simple arithmetic expressions such as "3+4" or "12 × 5".
    _MATH_EXPR_RE = re.compile(r'\d+\s*[+\-*/×÷=]\s*\d+')
    # Explicit step markers: "第1步", "步骤1", or a line starting "1." / "1、".
    _STEP_RE = re.compile(r'第\d+步|步骤\d+|^\d+[.、]', re.MULTILINE)

    def __init__(self, use_reference_comparison=True):
        """
        Args:
            use_reference_comparison: whether to compare the generated
                reasoning against a reference reasoning (``gt['reasoning']``)
                when one is provided.
        """
        self.format_pattern = re.compile(r"<think>(.*?)</think>\s*<answer>(.*?)</answer>", re.DOTALL)
        self.use_reference_comparison = use_reference_comparison

        # Reasoning keywords (used to gauge reasoning quality): connectives,
        # step words, arithmetic words and operator symbols.
        self.reasoning_keywords = [
            '计算', '因为', '所以', '首先', '然后', '接着', '最后', '根据',
            '第一步', '第二步', '第三步', '第', '步', '得到', '等于',
            '加', '减', '乘', '除', '=', '+', '-', '*', '/', '÷', '×'
        ]

    def parse_number(self, text):
        """Parse the final numeric value out of an answer string.

        Supports integers, decimals, negative numbers, scientific notation
        (``1.5e-3``), fractions (``1/5``), percentages (``20%`` -> 0.2) and
        thousands separators (``1,000``). Surrounding words, a few currency
        symbols and common Chinese unit words are ignored. When several
        numbers appear, the LAST one is taken as the answer.

        Args:
            text: raw answer text (may be empty or None).

        Returns:
            The parsed value as a float, or None when no number is found or
            the answer is a fraction with a zero denominator.
        """
        if not text:
            return None

        # Drop spaces, thousands separators, currency symbols and common
        # Chinese unit words so that e.g. "1,000 元" parses as 1000.
        clean_text = text.strip()
        for junk in (" ", ",", "¥", "$", "千克", "元", "个", "只", "本", "米", "人"):
            clean_text = clean_text.replace(junk, "")

        tokens = list(self._ANSWER_TOKEN_RE.finditer(clean_text))
        if not tokens:
            return None

        last = tokens[-1]
        value = float(last.group(1))
        if last.group(2) is not None:
            denominator = float(last.group(2))
            if denominator == 0:
                # "x/0" has no meaningful value; treat it as unparseable.
                return None
            value /= denominator
        if last.group(3):
            value /= 100
        return value

    def check_reasoning_quality(self, think_content):
        """Heuristically score the reasoning text on a 0.0-1.0 scale.

        Components (summed, then capped at 1.0):
          * length: +0.3 for >= 100 characters, +0.15 for >= 50;
          * keywords: +0.05 per distinct keyword present, at most +0.3;
          * arithmetic expressions: +0.2 if any, +0.1 more if >= 3;
          * explicit step markers: +0.1.

        Returns 0.0 for empty input.
        """
        if not think_content:
            return 0.0

        quality_score = 0.0

        length = len(think_content)
        if length >= 100:
            quality_score += 0.3
        elif length >= 50:
            quality_score += 0.15

        # Counts how many distinct keywords occur, not how often.
        keyword_count = sum(1 for kw in self.reasoning_keywords if kw in think_content)
        quality_score += min(keyword_count * 0.05, 0.3)

        math_expressions = self._MATH_EXPR_RE.findall(think_content)
        if math_expressions:
            quality_score += 0.2
            if len(math_expressions) >= 3:
                quality_score += 0.1

        if self._STEP_RE.search(think_content):
            quality_score += 0.1

        return min(quality_score, 1.0)

    def compute_reasoning_similarity(self, generated_reasoning, reference_reasoning):
        """Return the order-sensitive similarity (0.0-1.0) of two reasoning
        texts, as computed by ``difflib.SequenceMatcher.ratio``.

        Returns 0.0 if either text is empty or None.
        """
        if not generated_reasoning or not reference_reasoning:
            return 0.0

        return SequenceMatcher(None, generated_reasoning, reference_reasoning).ratio()

    @staticmethod
    def _answer_accuracy_reward(pred_val, gt_val):
        """Return the answer-accuracy component of the reward.

        -1.0 when no number could be parsed from the answer; 0.0 when the
        ground truth is None (nothing to compare against); +3.0 for a match
        within 1e-4; otherwise a penalty graded by relative error
        (< 10%: -0.3, < 50%: -0.8, else -1.5).
        """
        if pred_val is None:
            return -1.0
        if gt_val is None:
            return 0.0
        if math.isclose(pred_val, gt_val, rel_tol=1e-4, abs_tol=1e-4):
            return 3.0
        # The 1e-8 keeps the division defined when the ground truth is 0.
        relative_error = abs(pred_val - gt_val) / (abs(gt_val) + 1e-8)
        if relative_error < 0.1:
            return -0.3
        if relative_error < 0.5:
            return -0.8
        return -1.5

    def compute_rewards(self, completions, ground_truths):
        """Compute one scalar reward per completion.

        Args:
            completions: List[str] full texts generated by the model.
            ground_truths: List[dict] matching ground truths.
                Required: 'answer_val': float (None disables accuracy scoring).
                Optional: 'reasoning': str, 'reference_completion': str.

        Returns:
            rewards: List[float]. A completion lacking the
            ``<think>...</think> <answer>...</answer>`` structure gets -2.0.
            Otherwise the reward is 0.6 (format) plus the reasoning-quality,
            reference-similarity, accuracy and consistency components.
        """
        rewards = []

        for completion, gt in zip(completions, ground_truths):
            match = self.format_pattern.search(completion)
            if match is None:
                rewards.append(-2.0)
                continue

            think_content = match.group(1).strip()
            answer_content = match.group(2).strip()

            # Format bonus.
            total_reward = 0.6

            # Thin reasoning is penalized; otherwise reward scales with quality.
            reasoning_quality = self.check_reasoning_quality(think_content)
            if reasoning_quality < 0.3:
                total_reward -= 0.5
            else:
                total_reward += reasoning_quality * 1.0

            if self.use_reference_comparison and 'reasoning' in gt:
                similarity = self.compute_reasoning_similarity(think_content, gt['reasoning'])
                # Only reward similarity above a noise floor.
                if similarity > 0.3:
                    total_reward += similarity * 0.5

            pred_val = self.parse_number(answer_content)
            total_reward += self._answer_accuracy_reward(pred_val, gt['answer_val'])

            # Consistency bonus: the final answer also appears in the reasoning.
            if pred_val is not None and any(
                math.isclose(float(num), pred_val, rel_tol=1e-3, abs_tol=1e-3)
                for num in self._NUMBER_RE.findall(think_content)
            ):
                total_reward += 0.2

            rewards.append(total_reward)

        return rewards

    def compute_metrics(self, completions, ground_truths):
        """Aggregate evaluation metrics over a batch.

        Returns a dict with counts ('format_correct', 'answer_correct',
        'answer_close', 'total'), the mean reasoning quality over
        well-formatted completions, and the three counts as percentages of
        'total' (0.0 when the batch is empty).
        """
        metrics = {
            'format_correct': 0,
            'reasoning_quality_avg': 0.0,
            'answer_correct': 0,
            'answer_close': 0,
            'total': len(completions)
        }

        quality_scores = []

        for completion, gt in zip(completions, ground_truths):
            match = self.format_pattern.search(completion)
            if not match:
                continue

            metrics['format_correct'] += 1

            think_content = match.group(1).strip()
            answer_content = match.group(2).strip()

            quality_scores.append(self.check_reasoning_quality(think_content))
            pred_val = self.parse_number(answer_content)
            gt_val = gt['answer_val']

            if pred_val is not None and gt_val is not None:
                if math.isclose(pred_val, gt_val, rel_tol=1e-4, abs_tol=1e-4):
                    metrics['answer_correct'] += 1
                elif math.isclose(pred_val, gt_val, rel_tol=0.1, abs_tol=0.1):
                    metrics['answer_close'] += 1

        if quality_scores:
            metrics['reasoning_quality_avg'] = sum(quality_scores) / len(quality_scores)

        total = metrics['total']
        # Guard against an empty batch, which would otherwise divide by zero.
        if total:
            metrics['format_correct_pct'] = metrics['format_correct'] / total * 100
            metrics['answer_correct_pct'] = metrics['answer_correct'] / total * 100
            metrics['answer_close_pct'] = metrics['answer_close'] / total * 100
        else:
            metrics['format_correct_pct'] = 0.0
            metrics['answer_correct_pct'] = 0.0
            metrics['answer_close_pct'] = 0.0

        return metrics