File size: 14,904 Bytes
e783436
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
import os
import sys
# Add the parent directory (one level up) to sys.path so the imports below can resolve.
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)
from dataflow.utils.reasoning.AnswerExtraction import StringCleaner, UnitTextManager, AnswerExtractor
from prompts.bench_evaluate import AnswerJudgePromptQuestion, AnswerJudgeMultipleQuestionsPrompt
from dataflow.core.prompt import DIYPromptABC
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow.utils.storage import DataFlowStorage
from dataflow.core import LLMServingABC
from dataflow.core import OperatorABC

from math_verify import parse, verify
from dataflow import get_logger
from typing import Literal
import pandas as pd
import numpy as np
import time
import re
import json
import json5

@OPERATOR_REGISTRY.register()
class BenchDatasetEvaluatorQuestion(OperatorABC):
    def __init__(self,
                eval_result_path: str = None,
                compare_method: Literal["match", "semantic"] = "match",
                system_prompt: str = "You are a helpful assistant specialized in evaluating answer correctness.",
                llm_serving: LLMServingABC = None,
                prompt_template: DIYPromptABC = None,
                support_subquestions: bool = False,
                skip_true: bool = False,  # whether to skip samples already judged True
                keep_all_samples: bool = True,
                ):
        """Configure the evaluator.

        Args:
            eval_result_path: Path of the statistics JSON file. When None, a
                timestamped file under ``result_bencheval/`` is used.
            compare_method: ``"match"`` compares answers with
                ``math_verify_compare``; any other value selects the
                LLM-judged path.
            system_prompt: System prompt forwarded to the LLM in the
                LLM-judged path.
            llm_serving: LLM backend used in the LLM-judged path.
            prompt_template: Prompt builder for the LLM-judged path. When None,
                ``AnswerJudgeMultipleQuestionsPrompt`` is used if
                ``support_subquestions`` is set, otherwise
                ``AnswerJudgePromptQuestion``.
            support_subquestions: Whether LLM replies carry one verdict per
                sub-question. Ignored (forced to False) in ``"match"`` mode,
                which produces no per-sub-question columns.
            skip_true: Whether rows whose ``answer_match_result`` is already
                True are excluded from re-evaluation in the LLM-judged path.
            keep_all_samples: When no row has a reference answer, whether the
                untouched dataframe (True) or an empty one (False) is written.
        """
        if eval_result_path is None:
            timestamp = int(time.time())
            eval_result_path = f"result_bencheval/BenchDatasetEvaluator_result_{timestamp}.json"

        self.eval_result_path = eval_result_path
        self.compare_method = compare_method
        self.empty_responses_count = 0  # number of empty / unparsable LLM replies

        # These attributes are read by other methods regardless of the compare
        # method, so they are always defined instead of only in one branch.
        self.system_prompt = system_prompt
        self.llm_serving = llm_serving
        self.prompt_template = prompt_template
        self.skip_true = skip_true
        self.keep_all_samples = keep_all_samples
        self.support_subquestions = bool(support_subquestions) and compare_method != "match"

        if compare_method == "match":
            self.compare = self.math_verify_compare
            unit_manager = UnitTextManager()
            string_cleaner = StringCleaner(unit_manager)
            self.answer_extractor = AnswerExtractor(string_cleaner)
        elif prompt_template is None:
            if support_subquestions:
                self.prompt_template = AnswerJudgeMultipleQuestionsPrompt()
            else:
                self.prompt_template = AnswerJudgePromptQuestion()

        self.logger = get_logger()
    
    def math_verify_compare(self, answer, ground_truth):
        """Compare ``answer`` with ``ground_truth`` using ``math_verify``.

        The stringified values are tried first; if that attempt raises, the
        raw values are tried. Any failure of the second attempt yields False.

        Args:
            answer: Predicted answer.
            ground_truth: Reference answer.

        Returns:
            The result of ``verify(parse(ground_truth), parse(answer))``, or
            False when both attempts raise.
        """
        # ``except Exception`` instead of a bare ``except`` so that
        # KeyboardInterrupt / SystemExit are not swallowed.
        try:
            return verify(parse(str(ground_truth)), parse(str(answer)))
        except Exception:
            pass
        try:
            return verify(parse(ground_truth), parse(answer))
        except Exception:
            return False

    def ResolveResponse(self, response):
        # 检查空响应
        if not self.support_subquestions:
            if response is None or (isinstance(response, str) and response.strip() == ''):
                self.empty_responses_count += 1
                return False
            try:
                pattern = re.compile(r'"judgement_result"\s*:\s*(true|false)', re.IGNORECASE)
                match = pattern.search(response)
                result_value = None
                if match:
                    result_value = match.group(1).lower()
                else:
                    # 备用解析逻辑,检查响应中是否包含true或false
                    if "true" in response.lower():
                        result_value = "true"
                    else:
                        result_value = "false"
                if result_value == "true":
                    return True
                else:
                    return False
            except Exception as e:
                self.logger.error(f"Response format error: {response}. Error: {e}")
                return False
        
        if self.support_subquestions:
            # 如果支持子问题,假设response是一个列表, 返回正确的数量/总数
            correct_num = 0
            total_num = 0
            try:
                response = json5.loads(response, strict=False)  # 使用json5解析,允许更宽松的格式
                judgement = response.get("judgement", [])
            except Exception as e:
                self.logger.error(f"Response JSON parse error: {response}. Error: {e}")
                self.empty_responses_count += 1
                return "0/0"
            for resp in judgement:
                if isinstance(resp, bool): 
                    if resp is True:
                        correct_num += 1
                        total_num += 1
                    elif resp is False:
                        total_num += 1
                    elif resp.lower() == "empty":
                        continue  # 不计入总数
                elif isinstance(resp, str):
                    if resp.lower() == "true":
                        correct_num += 1
                        total_num += 1
                    elif resp.lower() == "false":
                        total_num += 1
                    elif resp.lower() == "empty":
                        continue  # 不计入总数
                    
            return f"{correct_num}/{total_num}"
            
    @staticmethod
    def get_desc(lang: str = "zh"):
        if lang == "zh":
            return (
                "该算子用于对比预测答案与标准答案的匹配度,支持两种评估模式:\n\n"
                "1. 字符串匹配(match):使用数学验证方法比较答案,适用于有明确答案的问题\n"
                "2. 语义匹配(semantic):使用LLM评估答案的语义相似度,适用于开放性问题\n\n"
                "输入参数:\n"
                "- input_test_answer_key:预测答案字段名\n"
                "- input_gt_answer_key:标准答案字段名\n"
                "- input_question_key:问题字段名(语义匹配模式下必需)\n"
                "- compare_method:比较方法(match/semantic)\n\n"
                "输出参数:\n"
                "- answer_match_result:匹配结果(True/False)\n"
                "- 统计结果将保存到指定的eval_result_path路径\n"
            )
        elif lang == "en":
            return (
                "This operator compares predicted answers against ground truth using two evaluation modes:\n\n"
                "1. String Matching (match): Uses mathematical verification to compare answers, suitable for questions with definitive answers\n"
                "2. Semantic Matching (semantic): Uses LLM to evaluate semantic similarity, suitable for open-ended questions\n\n"
                "Input Parameters:\n"
                "- input_test_answer_key: Predicted answer field\n"
                "- input_gt_answer_key: Ground truth field\n"
                "- input_question_key: Question field (required for semantic mode)\n"
                "- compare_method: Comparison method (match/semantic)\n\n"
                "Output Parameters:\n"
                "- answer_match_result: Matching result (True/False)\n"
                "- Statistics will be saved to the specified eval_result_path\n"
            )
        else:
            return "BenchEvaluator performs answer validation using string matching or semantic comparison"
        
    def check_column(self, required_columns: list[str], dataframe: pd.DataFrame):
        for column in required_columns:
            if column not in dataframe.columns:
                self.logger.error(f"Required column '{column}' not found in dataframe")
                return False
        return True
            
    def statistic(self, file_name_prefix: str, dataframe: pd.DataFrame, compare_method: Literal["match", "semantic"]):
        total_samples = len(dataframe)
        valid_samples = len(dataframe) - self.empty_responses_count
        matched_samples = sum(dataframe['answer_match_result'])
        accuracy = matched_samples / valid_samples if valid_samples > 0 else 0
        
        # 创建统计信息字典
        stats = {
            "bench_name_or_prefix": file_name_prefix,
            "total_samples": total_samples,
            "valid_samples": valid_samples,
            "matched_samples": matched_samples,
            "accuracy": float(accuracy),  # 确保可以被JSON序列化
            "empty_responses_count": self.empty_responses_count,
            "compare_method": compare_method
        }
        
        if self.support_subquestions:
            total_subquestions = dataframe['total_subquestions'].sum()
            correct_subquestions = dataframe['correct_answer_num'].sum()
            subquestion_accuracy = correct_subquestions / total_subquestions if total_subquestions > 0 else 0
            stats.update({
                "total_subquestions": int(total_subquestions),
                "correct_subquestions": int(correct_subquestions),
                "subquestion_accuracy": float(subquestion_accuracy)
            })
        
        # 将字典转换为DataFrame
        stats_df = pd.DataFrame([stats])
        
        # 直接将统计信息写入到self.eval_result_path
        os.makedirs(os.path.dirname(self.eval_result_path), exist_ok=True)
        stats_df.to_json(self.eval_result_path, orient="records", force_ascii=False, indent=2)
        self.logger.success(f"Statistics saved to {self.eval_result_path}")
        
        return stats_df
        
    def run(
            self,
            storage:DataFlowStorage,
            input_test_answer_key: str = "generated_cot",
            input_gt_answer_key: str = "golden_answer",
            input_question_key: str = None,
            ) -> list:
        """Evaluate predicted answers and persist results plus statistics.

        In ``"match"`` mode each predicted answer is passed through
        ``self.answer_extractor.extract_answer`` and compared with the
        reference via ``self.compare``. Otherwise a prompt is built for every
        row that has a reference answer (and, with ``skip_true``, is not
        already marked True), sent to ``self.llm_serving`` and the replies are
        resolved with ``ResolveResponse``.

        Args:
            storage: Storage the dataframe is read from and written to.
            input_test_answer_key: Column holding the predicted answers.
            input_gt_answer_key: Column holding the reference answers.
            input_question_key: Column holding the questions (needed by the
                LLM-judged path).

        Returns:
            A list of column names. If a required column is missing, the list
            of required columns is returned and nothing is written.

        Raises:
            ValueError: If the LLM returns a different number of replies than
                prompts, since results could not be aligned with rows.
        """
        self.test_answer_key = input_test_answer_key
        self.gt_answer_key = input_gt_answer_key
        self.question_key = input_question_key
        # Start every run with a fresh counter so an earlier aborted run
        # cannot leak into this run's statistics.
        self.empty_responses_count = 0

        dataframe = storage.read("dataframe")
        if 'answer_match_result' not in dataframe.columns:
            dataframe['answer_match_result'] = False

        if self.compare_method == "match":
            required_columns = [input_test_answer_key, input_gt_answer_key]
            # Validate before touching the columns so a missing column is
            # reported through the logger instead of raising KeyError.
            if not self.check_column(required_columns=required_columns, dataframe=dataframe):
                return required_columns

            # Iterate by value so the result does not depend on the index labels.
            dataframe['answer_match_result'] = [
                bool(self.compare(self.answer_extractor.extract_answer(answer, None), truth))
                for answer, truth in zip(
                    dataframe[input_test_answer_key], dataframe[input_gt_answer_key]
                )
            ]

            storage.write(dataframe)
            self.statistic(storage.file_name_prefix, dataframe, self.compare_method)

            return [self.test_answer_key, self.gt_answer_key, 'answer_match_result']

        required_columns = [input_test_answer_key, input_gt_answer_key, input_question_key]
        if not self.check_column(required_columns=required_columns, dataframe=dataframe):
            return required_columns

        # Rows without a reference answer are never sent to the LLM.
        skip_mask = dataframe[input_gt_answer_key].isna() | (dataframe[input_gt_answer_key] == '')
        if self.skip_true:
            skip_mask = skip_mask | dataframe['answer_match_result'].eq(True)
        valid_rows = dataframe[~skip_mask]
        skipped_count = int(skip_mask.sum())

        if len(valid_rows) == 0 and not self.skip_true:
            self.logger.warning("No valid samples with reference answers found. All samples skipped.")
            # Default to keeping the rows when the attribute was never configured.
            if getattr(self, "keep_all_samples", True):
                output_file = storage.write(dataframe)
            else:
                output_file = storage.write(pd.DataFrame(columns=dataframe.columns))
            self.logger.info(f"Dataframe saved to {output_file}. Skipped {skipped_count} samples due to missing reference answers.")
            return required_columns + ['answer_match_result']

        inputs = [
            self.prompt_template.build_prompt(
                question=row[input_question_key],
                answer=row[input_test_answer_key],
                reference_answer=row[input_gt_answer_key],
            )
            for _, row in valid_rows.iterrows()
        ]

        # Nothing to judge (possible with skip_true): do not call the LLM.
        if inputs:
            responses = self.llm_serving.generate_from_input(
                user_inputs=inputs, system_prompt=self.system_prompt
            )
        else:
            responses = []

        if len(responses) != len(inputs):
            raise ValueError(
                f"LLM returned {len(responses)} responses for {len(inputs)} prompts; "
                "results cannot be aligned with rows."
            )

        results = [self.ResolveResponse(response) for response in responses]

        if not self.support_subquestions:
            for idx, verdict in zip(valid_rows.index, results):
                dataframe.at[idx, 'answer_match_result'] = verdict
        else:
            for idx, verdict, raw_response in zip(valid_rows.index, results, responses):
                # ResolveResponse returns "<correct>/<total>" in this mode.
                correct_text, total_text = verdict.split('/')[:2]
                correct_answer_num = int(correct_text)
                total_subquestions = int(total_text)
                dataframe.at[idx, 'correct_answer_num'] = correct_answer_num
                dataframe.at[idx, 'total_subquestions'] = total_subquestions
                # A row matches only when every counted sub-question is correct.
                dataframe.at[idx, 'answer_match_result'] = (correct_answer_num == total_subquestions) and (total_subquestions > 0)
                dataframe.at[idx, 'response_evaluation'] = raw_response  # keep the raw LLM reply

        storage.write(dataframe)
        self.statistic(storage.file_name_prefix, dataframe, self.compare_method)

        # Reset so the counter does not carry over to a later run.
        self.empty_responses_count = 0

        return [input_test_answer_key, input_gt_answer_key, input_question_key, 'answer_match_result']