Spaces:
Sleeping
Sleeping
File size: 14,904 Bytes
import os
import sys
# Put the parent directory (one level above this file) at the front of sys.path.
# NOTE(review): presumably needed so the sibling ``prompts`` package imported
# below resolves when this file is run from its own folder - confirm.
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)
from dataflow.utils.reasoning.AnswerExtraction import StringCleaner, UnitTextManager, AnswerExtractor
from prompts.bench_evaluate import AnswerJudgePromptQuestion, AnswerJudgeMultipleQuestionsPrompt
from dataflow.core.prompt import DIYPromptABC
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow.utils.storage import DataFlowStorage
from dataflow.core import LLMServingABC
from dataflow.core import OperatorABC
from math_verify import parse, verify
from dataflow import get_logger
from typing import Literal
import pandas as pd
import numpy as np
import time
import re
import json
import json5
@OPERATOR_REGISTRY.register()
class BenchDatasetEvaluatorQuestion(OperatorABC):
    """Compare predicted answers with reference answers and save summary statistics.

    Two comparison modes are supported:

    * ``"match"``: a final answer is extracted from the prediction with
      ``AnswerExtractor`` and compared to the reference via ``math_verify``.
    * ``"semantic"``: a prompt is built for every sample and sent to
      ``llm_serving``; the reply is parsed by :meth:`ResolveResponse`.

    The per-row verdict is stored in the ``answer_match_result`` column and the
    aggregated statistics are written as JSON to ``eval_result_path``.
    """

    # Matches `"judgement_result": true|false` inside an LLM reply.
    _JUDGEMENT_RE = re.compile(r'"judgement_result"\s*:\s*(true|false)', re.IGNORECASE)

    def __init__(self,
                 eval_result_path: str = None,
                 compare_method: Literal["match", "semantic"] = "match",
                 system_prompt: str = "You are a helpful assistant specialized in evaluating answer correctness.",
                 llm_serving: LLMServingABC = None,
                 prompt_template: DIYPromptABC = None,
                 support_subquestions: bool = False,
                 skip_true: bool = False,
                 keep_all_samples: bool = True,
                 ):
        """Configure the evaluator.

        Args:
            eval_result_path: Where the statistics JSON is written. When ``None``
                a timestamped file under ``result_bencheval/`` is used.
            compare_method: ``"match"`` or ``"semantic"`` (any value other than
                ``"match"`` selects the LLM-judge path).
            system_prompt: System prompt passed to ``llm_serving``.
            llm_serving: LLM backend used in semantic mode.
            prompt_template: Prompt builder for semantic mode. When ``None`` a
                default is chosen depending on ``support_subquestions``.
            support_subquestions: When true, the judge reply is expected to carry
                a list of per-sub-question verdicts.
            skip_true: When true, rows whose ``answer_match_result`` is already
                ``True`` are not re-evaluated in semantic mode.
            keep_all_samples: Only used when semantic mode finds no row to
                evaluate: write the dataframe unchanged (``True``) or write an
                empty dataframe with the same columns (``False``).
        """
        if eval_result_path is None:
            timestamp = int(time.time())
            eval_result_path = f"result_bencheval/BenchDatasetEvaluator_result_{timestamp}.json"

        self.eval_result_path = eval_result_path
        self.compare_method = compare_method
        # Number of LLM replies that were empty or could not be parsed.
        self.empty_responses_count = 0

        if compare_method == "match":
            self.compare = self.math_verify_compare
            unit_manager = UnitTextManager()
            string_cleaner = StringCleaner(unit_manager)
            self.answer_extractor = AnswerExtractor(string_cleaner)
        elif prompt_template is None:
            prompt_template = (
                AnswerJudgeMultipleQuestionsPrompt()
                if support_subquestions
                else AnswerJudgePromptQuestion()
            )

        # Stored in every mode so later attribute access never fails.
        self.prompt_template = prompt_template
        self.system_prompt = system_prompt
        self.llm_serving = llm_serving
        self.support_subquestions = support_subquestions
        self.skip_true = skip_true
        self.keep_all_samples = keep_all_samples
        self.logger = get_logger()

    def math_verify_compare(self, answer, ground_truth):
        """Return the ``math_verify`` verdict for ``answer`` against ``ground_truth``.

        Both values are first compared as strings; if that raises, the raw
        values are tried; if that raises as well the result is ``False``.
        """
        try:
            return verify(parse(str(ground_truth)), parse(str(answer)))
        except Exception:
            try:
                return verify(parse(ground_truth), parse(answer))
            except Exception:
                return False

    def ResolveResponse(self, response):
        """Turn one LLM judge reply into a verdict.

        Returns:
            ``bool`` when ``support_subquestions`` is false, otherwise a string
            ``"<correct>/<total>"`` counting the sub-question verdicts.
        """
        if self.support_subquestions:
            return self._resolve_subquestion_judgements(response)
        return self._resolve_single_judgement(response)

    def _resolve_single_judgement(self, response):
        """Parse a single true/false verdict; empty replies count as ``False``."""
        if response is None or (isinstance(response, str) and response.strip() == ''):
            self.empty_responses_count += 1
            return False
        try:
            match = self._JUDGEMENT_RE.search(response)
            if match:
                return match.group(1).lower() == "true"
            # Fallback: accept any reply that mentions "true" anywhere.
            return "true" in response.lower()
        except Exception as e:
            self.logger.error(f"Response format error: {response}. Error: {e}")
            return False

    def _resolve_subquestion_judgements(self, response):
        """Parse a JSON reply with a ``judgement`` list into ``"correct/total"``."""
        try:
            # json5 tolerates the looser JSON that LLMs tend to emit.
            parsed = json5.loads(response, strict=False)
            judgement = parsed.get("judgement", [])
            if not isinstance(judgement, (list, tuple)):
                raise TypeError(f"'judgement' must be a list, got {type(judgement).__name__}")
        except Exception as e:
            self.logger.error(f"Response JSON parse error: {response}. Error: {e}")
            self.empty_responses_count += 1
            return "0/0"

        correct_num = 0
        total_num = 0
        for verdict in judgement:
            if isinstance(verdict, bool):
                total_num += 1
                if verdict:
                    correct_num += 1
            elif isinstance(verdict, str):
                label = verdict.lower()
                if label == "true":
                    correct_num += 1
                    total_num += 1
                elif label == "false":
                    total_num += 1
                # "empty" (and any other label) is not counted in the total.
        return f"{correct_num}/{total_num}"

    @staticmethod
    def get_desc(lang: str = "zh"):
        """Return a human-readable description of the operator in ``lang``.

        ``"zh"`` and ``"en"`` give the full description; any other value gives
        a one-line English summary.
        """
        if lang == "zh":
            return (
                "该算子用于对比预测答案与标准答案的匹配度,支持两种评估模式:\n\n"
                "1. 字符串匹配(match):使用数学验证方法比较答案,适用于有明确答案的问题\n"
                "2. 语义匹配(semantic):使用LLM评估答案的语义相似度,适用于开放性问题\n\n"
                "输入参数:\n"
                "- input_test_answer_key:预测答案字段名\n"
                "- input_gt_answer_key:标准答案字段名\n"
                "- input_question_key:问题字段名(语义匹配模式下必需)\n"
                "- compare_method:比较方法(match/semantic)\n\n"
                "输出参数:\n"
                "- answer_match_result:匹配结果(True/False)\n"
                "- 统计结果将保存到指定的eval_result_path路径\n"
            )
        elif lang == "en":
            return (
                "This operator compares predicted answers against ground truth using two evaluation modes:\n\n"
                "1. String Matching (match): Uses mathematical verification to compare answers, suitable for questions with definitive answers\n"
                "2. Semantic Matching (semantic): Uses LLM to evaluate semantic similarity, suitable for open-ended questions\n\n"
                "Input Parameters:\n"
                "- input_test_answer_key: Predicted answer field\n"
                "- input_gt_answer_key: Ground truth field\n"
                "- input_question_key: Question field (required for semantic mode)\n"
                "- compare_method: Comparison method (match/semantic)\n\n"
                "Output Parameters:\n"
                "- answer_match_result: Matching result (True/False)\n"
                "- Statistics will be saved to the specified eval_result_path\n"
            )
        else:
            return "BenchEvaluator performs answer validation using string matching or semantic comparison"

    def check_column(self, required_columns: list[str], dataframe: pd.DataFrame) -> bool:
        """Return ``True`` if every required column exists; log the first missing one."""
        for column in required_columns:
            if column not in dataframe.columns:
                self.logger.error(f"Required column '{column}' not found in dataframe")
                return False
        return True

    def statistic(self, file_name_prefix: str, dataframe: pd.DataFrame, compare_method: Literal["match", "semantic"]):
        """Aggregate the evaluation results and write them to ``eval_result_path``.

        Accuracy is ``matched / (total - empty_responses_count)``, or 0 when no
        valid sample remains. Sub-question totals are added when
        ``support_subquestions`` is set and the per-row count columns exist.

        Returns:
            A one-row ``DataFrame`` holding the statistics that were written.
        """
        total_samples = len(dataframe)
        valid_samples = total_samples - self.empty_responses_count
        # Cast to a plain int so the value is JSON friendly.
        matched_samples = int(dataframe['answer_match_result'].sum())
        accuracy = matched_samples / valid_samples if valid_samples > 0 else 0

        stats = {
            "bench_name_or_prefix": file_name_prefix,
            "total_samples": total_samples,
            "valid_samples": valid_samples,
            "matched_samples": matched_samples,
            "accuracy": float(accuracy),
            "empty_responses_count": self.empty_responses_count,
            "compare_method": compare_method
        }

        # The count columns only exist after a sub-question evaluation pass.
        has_subquestion_columns = (
            'total_subquestions' in dataframe.columns
            and 'correct_answer_num' in dataframe.columns
        )
        if self.support_subquestions and has_subquestion_columns:
            total_subquestions = dataframe['total_subquestions'].sum()
            correct_subquestions = dataframe['correct_answer_num'].sum()
            subquestion_accuracy = correct_subquestions / total_subquestions if total_subquestions > 0 else 0
            stats.update({
                "total_subquestions": int(total_subquestions),
                "correct_subquestions": int(correct_subquestions),
                "subquestion_accuracy": float(subquestion_accuracy)
            })

        stats_df = pd.DataFrame([stats])

        # os.makedirs('') raises, so only create a directory when the path has one.
        result_dir = os.path.dirname(self.eval_result_path)
        if result_dir:
            os.makedirs(result_dir, exist_ok=True)
        stats_df.to_json(self.eval_result_path, orient="records", force_ascii=False, indent=2)
        self.logger.success(f"Statistics saved to {self.eval_result_path}")

        return stats_df

    def run(
        self,
        storage: DataFlowStorage,
        input_test_answer_key: str = "generated_cot",
        input_gt_answer_key: str = "golden_answer",
        input_question_key: str = None,
    ) -> list:
        """Evaluate every row read from ``storage`` and write the results back.

        Args:
            storage: Source and sink of the dataframe.
            input_test_answer_key: Column holding the predicted answer.
            input_gt_answer_key: Column holding the reference answer.
            input_question_key: Column holding the question (semantic mode).

        Returns:
            A list of column names. When a required column is missing, an error
            is logged, nothing is written, and the required column names are
            returned.
        """
        self.test_answer_key = input_test_answer_key
        self.gt_answer_key = input_gt_answer_key
        self.question_key = input_question_key

        dataframe = storage.read("dataframe")

        is_match_mode = self.compare_method == "match"
        required_columns = [input_test_answer_key, input_gt_answer_key]
        if not is_match_mode:
            required_columns.append(input_question_key)

        # Validate before touching any column so a missing one is reported
        # through the logger instead of surfacing as a KeyError.
        if not self.check_column(required_columns=required_columns, dataframe=dataframe):
            return required_columns

        if 'answer_match_result' not in dataframe.columns:
            dataframe['answer_match_result'] = False

        if is_match_mode:
            # Iterate by index label so dataframes without a 0..n-1 index work too.
            rows = zip(
                dataframe.index,
                dataframe[input_test_answer_key],
                dataframe[input_gt_answer_key],
            )
            for idx, answer, ground_truth in rows:
                final_answer = self.answer_extractor.extract_answer(answer, None)
                dataframe.at[idx, 'answer_match_result'] = bool(self.compare(final_answer, ground_truth))

            storage.write(dataframe)
            self.statistic(storage.file_name_prefix, dataframe, self.compare_method)
            return [self.test_answer_key, self.gt_answer_key, 'answer_match_result']

        # Semantic mode: rows without a reference answer are never sent to the LLM.
        reference = dataframe[input_gt_answer_key]
        skip_mask = reference.isna() | (reference == '')
        if self.skip_true:
            # Rows already judged correct are not evaluated again.
            skip_mask = skip_mask | dataframe['answer_match_result'].eq(True)

        valid_rows = dataframe[~skip_mask]
        skipped_count = int(skip_mask.sum())

        if len(valid_rows) == 0 and not self.skip_true:
            self.logger.warning("No valid samples with reference answers found. All samples skipped.")
            if self.keep_all_samples:
                # Keep every row; answer_match_result stays as it was.
                output_file = storage.write(dataframe)
            else:
                # Keep no rows, only the column layout.
                output_file = storage.write(pd.DataFrame(columns=dataframe.columns))
            self.logger.info(f"Dataframe saved to {output_file}. Skipped {skipped_count} samples due to missing reference answers.")
            return required_columns + ['answer_match_result']

        inputs = [
            self.prompt_template.build_prompt(
                question=row[input_question_key],
                answer=row[input_test_answer_key],
                reference_answer=row[input_gt_answer_key]
            )
            for _, row in valid_rows.iterrows()
        ]
        # Nothing to judge (possible with skip_true): do not call the LLM at all.
        responses = (
            self.llm_serving.generate_from_input(user_inputs=inputs, system_prompt=self.system_prompt)
            if inputs
            else []
        )

        # Start from a clean counter so a previously aborted run cannot skew the stats.
        self.empty_responses_count = 0
        results = [self.ResolveResponse(response) for response in responses]

        if not self.support_subquestions:
            for idx, verdict in zip(valid_rows.index, results):
                dataframe.at[idx, 'answer_match_result'] = verdict
        else:
            for idx, response, verdict in zip(valid_rows.index, responses, results):
                correct_part, total_part = verdict.split('/')
                correct_answer_num = int(correct_part)
                total_subquestions = int(total_part)
                dataframe.at[idx, 'correct_answer_num'] = correct_answer_num
                dataframe.at[idx, 'total_subquestions'] = total_subquestions
                # A row only counts as matched when every sub-question is correct.
                dataframe.at[idx, 'answer_match_result'] = (
                    (correct_answer_num == total_subquestions) and (total_subquestions > 0)
                )
                # Keep the raw LLM reply for later inspection.
                dataframe.at[idx, 'response_evaluation'] = response

        storage.write(dataframe)
        self.statistic(storage.file_name_prefix, dataframe, self.compare_method)

        # Reset the counter so the next run starts from zero.
        self.empty_responses_count = 0

        return [input_test_answer_key, input_gt_answer_key, input_question_key, 'answer_match_result']
|