import os
import sys
# Add the parent directory (one level up) to sys.path
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
if parent_dir not in sys.path:
sys.path.insert(0, parent_dir)
from dataflow.utils.reasoning.AnswerExtraction import StringCleaner, UnitTextManager, AnswerExtractor
from prompts.bench_evaluate import AnswerJudgePromptQuestion, AnswerJudgeMultipleQuestionsPrompt
from dataflow.core.prompt import DIYPromptABC
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow.utils.storage import DataFlowStorage
from dataflow.core import LLMServingABC
from dataflow.core import OperatorABC
from math_verify import parse, verify
from dataflow import get_logger
from typing import Literal
import pandas as pd
import time
import re
import json5
@OPERATOR_REGISTRY.register()
class BenchDatasetEvaluatorQuestion(OperatorABC):
    def __init__(self,
                 eval_result_path: str = None,
                 compare_method: Literal["match", "semantic"] = "match",
                 system_prompt: str = "You are a helpful assistant specialized in evaluating answer correctness.",
                 llm_serving: LLMServingABC = None,
                 prompt_template: DIYPromptABC = None,
                 support_subquestions: bool = False,
                 skip_true: bool = False,        # skip samples already verified as True
                 keep_all_samples: bool = True,  # keep skipped rows in the written output
                 ):
if eval_result_path is None:
timestamp = int(time.time())
eval_result_path = f"result_bencheval/BenchDatasetEvaluator_result_{timestamp}.json"
self.eval_result_path = eval_result_path
self.compare_method = compare_method
        self.empty_responses_count = 0  # counter for empty LLM responses
if compare_method == "match":
self.compare = self.math_verify_compare
unit_manager = UnitTextManager()
string_cleaner = StringCleaner(unit_manager)
self.answer_extractor = AnswerExtractor(string_cleaner)
else:
if prompt_template is None:
prompt_template = AnswerJudgePromptQuestion() if not support_subquestions else AnswerJudgeMultipleQuestionsPrompt()
self.prompt_template = prompt_template
self.system_prompt = system_prompt
self.llm_serving = llm_serving
self.support_subquestions = support_subquestions
        self.skip_true = skip_true
        self.keep_all_samples = keep_all_samples
        self.logger = get_logger()
    def math_verify_compare(self, answer, ground_truth):
        # Try string-coerced parsing first, then fall back to the raw values.
        try:
            return verify(parse(str(ground_truth)), parse(str(answer)))
        except Exception:
            try:
                return verify(parse(ground_truth), parse(answer))
            except Exception:
                return False
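    # Illustrative behaviour of the comparison above (assumed outcomes, not
    # taken from a test run; math_verify resolves mathematically equivalent
    # forms rather than comparing raw strings):
    #   math_verify_compare("0.5", "1/2")  -> True   (equivalent values)
    #   math_verify_compare("x+1", "1+x")  -> True   (equivalent expressions)
    #   math_verify_compare("42", "41")    -> False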
    def ResolveResponse(self, response):
        if not self.support_subquestions:
            # Guard against empty responses.
            if response is None or (isinstance(response, str) and response.strip() == ''):
                self.empty_responses_count += 1
                return False
            try:
                pattern = re.compile(r'"judgement_result"\s*:\s*(true|false)', re.IGNORECASE)
                match = pattern.search(response)
                if match:
                    result_value = match.group(1).lower()
                else:
                    # Fallback: scan the raw response for a true/false verdict.
                    result_value = "true" if "true" in response.lower() else "false"
                return result_value == "true"
            except Exception as e:
                self.logger.error(f"Response format error: {response}. Error: {e}")
                return False
        # Subquestion mode: the response should carry a list of verdicts;
        # return "correct/total" as a string.
        correct_num = 0
        total_num = 0
        try:
            response = json5.loads(response)  # json5 tolerates loosely formatted JSON
            judgement = response.get("judgement", [])
        except Exception as e:
            self.logger.error(f"Response JSON parse error: {response}. Error: {e}")
            self.empty_responses_count += 1
            return "0/0"
        for resp in judgement:
            if isinstance(resp, bool):
                if resp:
                    correct_num += 1
                total_num += 1
            elif isinstance(resp, str):
                if resp.lower() == "true":
                    correct_num += 1
                    total_num += 1
                elif resp.lower() == "false":
                    total_num += 1
                elif resp.lower() == "empty":
                    continue  # "empty" verdicts are excluded from the total
        return f"{correct_num}/{total_num}"
@staticmethod
def get_desc(lang: str = "zh"):
if lang == "zh":
return (
"该算子用于对比预测答案与标准答案的匹配度,支持两种评估模式:\n\n"
"1. 字符串匹配(match):使用数学验证方法比较答案,适用于有明确答案的问题\n"
"2. 语义匹配(semantic):使用LLM评估答案的语义相似度,适用于开放性问题\n\n"
"输入参数:\n"
"- input_test_answer_key:预测答案字段名\n"
"- input_gt_answer_key:标准答案字段名\n"
"- input_question_key:问题字段名(语义匹配模式下必需)\n"
"- compare_method:比较方法(match/semantic)\n\n"
"输出参数:\n"
"- answer_match_result:匹配结果(True/False)\n"
"- 统计结果将保存到指定的eval_result_path路径\n"
)
elif lang == "en":
return (
"This operator compares predicted answers against ground truth using two evaluation modes:\n\n"
"1. String Matching (match): Uses mathematical verification to compare answers, suitable for questions with definitive answers\n"
"2. Semantic Matching (semantic): Uses LLM to evaluate semantic similarity, suitable for open-ended questions\n\n"
"Input Parameters:\n"
"- input_test_answer_key: Predicted answer field\n"
"- input_gt_answer_key: Ground truth field\n"
"- input_question_key: Question field (required for semantic mode)\n"
"- compare_method: Comparison method (match/semantic)\n\n"
"Output Parameters:\n"
"- answer_match_result: Matching result (True/False)\n"
"- Statistics will be saved to the specified eval_result_path\n"
)
else:
return "BenchEvaluator performs answer validation using string matching or semantic comparison"
def check_column(self, required_columns: list[str], dataframe: pd.DataFrame):
for column in required_columns:
if column not in dataframe.columns:
self.logger.error(f"Required column '{column}' not found in dataframe")
return False
return True
    def statistic(self, file_name_prefix: str, dataframe: pd.DataFrame, compare_method: Literal["match", "semantic"]):
        total_samples = len(dataframe)
        valid_samples = len(dataframe) - self.empty_responses_count
        matched_samples = int(dataframe['answer_match_result'].sum())
        accuracy = matched_samples / valid_samples if valid_samples > 0 else 0
        # Assemble the statistics record.
        stats = {
            "bench_name_or_prefix": file_name_prefix,
            "total_samples": total_samples,
            "valid_samples": valid_samples,
            "matched_samples": matched_samples,
            "accuracy": float(accuracy),  # ensure JSON-serializable
            "empty_responses_count": self.empty_responses_count,
            "compare_method": compare_method
        }
        if self.support_subquestions:
            total_subquestions = dataframe['total_subquestions'].sum()
            correct_subquestions = dataframe['correct_answer_num'].sum()
            subquestion_accuracy = correct_subquestions / total_subquestions if total_subquestions > 0 else 0
            stats.update({
                "total_subquestions": int(total_subquestions),
                "correct_subquestions": int(correct_subquestions),
                "subquestion_accuracy": float(subquestion_accuracy)
            })
        # Write the statistics straight to self.eval_result_path as a one-record JSON list.
        stats_df = pd.DataFrame([stats])
        out_dir = os.path.dirname(self.eval_result_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        stats_df.to_json(self.eval_result_path, orient="records", force_ascii=False, indent=2)
        self.logger.success(f"Statistics saved to {self.eval_result_path}")
        return stats_df
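    # The file written above holds a single-record JSON list; roughly (field
    # values are made up for illustration):
    #   [{"bench_name_or_prefix": "demo", "total_samples": 100,
    #     "valid_samples": 98, "matched_samples": 73, "accuracy": 0.7449,
    #     "empty_responses_count": 2, "compare_method": "match"}]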
def run(
self,
        storage: DataFlowStorage,
input_test_answer_key: str = "generated_cot",
input_gt_answer_key: str = "golden_answer",
input_question_key: str = None,
) -> list:
self.test_answer_key = input_test_answer_key
self.gt_answer_key = input_gt_answer_key
self.question_key = input_question_key
dataframe = storage.read("dataframe")
if 'answer_match_result' not in dataframe.columns:
dataframe['answer_match_result'] = False
        if self.compare_method == "match":
            required_columns = [input_test_answer_key, input_gt_answer_key]
            # Validate columns before indexing into the dataframe.
            if self.check_column(required_columns=required_columns, dataframe=dataframe) is False:
                return required_columns
            answers = dataframe[self.test_answer_key]
            ground_truths = dataframe[self.gt_answer_key]
            for i in range(len(answers)):
                final_answer = self.answer_extractor.extract_answer(answers[i], None)
                dataframe.at[i, 'answer_match_result'] = bool(self.compare(final_answer, ground_truths[i]))
            output_file = storage.write(dataframe)
            # Build the statistics and write them straight to the JSON file.
            stats = self.statistic(storage.file_name_prefix, dataframe, self.compare_method)
            return [self.test_answer_key, self.gt_answer_key, 'answer_match_result']
        else:
            required_columns = [input_test_answer_key, input_gt_answer_key, input_question_key]
            if self.check_column(required_columns=required_columns, dataframe=dataframe) is False:
                return required_columns
            empty_reference_mask = dataframe[input_gt_answer_key].isna() | (dataframe[input_gt_answer_key] == '')
            if self.skip_true:
                empty_reference_mask = empty_reference_mask | (dataframe['answer_match_result'] == True)
            skipped_rows = dataframe[empty_reference_mask]
            valid_rows = dataframe[~empty_reference_mask]
            skipped_count = len(skipped_rows)
            if len(valid_rows) == 0 and not self.skip_true:
                self.logger.warning("No valid samples with reference answers found. All samples skipped.")
                if self.keep_all_samples:
                    output_file = storage.write(dataframe)  # keep every row; answer_match_result stays False
                else:
                    output_file = storage.write(pd.DataFrame(columns=dataframe.columns))  # keep no rows
                self.logger.info(f"Dataframe saved to {output_file}. Skipped {skipped_count} samples due to missing reference answers.")
                return required_columns + ['answer_match_result']
            # Build prompts and call the LLM only for rows that have a reference answer.
            inputs = [self.prompt_template.build_prompt(
                question=row[input_question_key],
                answer=row[input_test_answer_key],
                reference_answer=row[input_gt_answer_key]
            ) for _, row in valid_rows.iterrows()]
            responses = self.llm_serving.generate_from_input(user_inputs=inputs, system_prompt=self.system_prompt)
            # if self.support_subquestions:
            #     # Each response is a list; flatten them into one long list,
            #     # e.g. [["true", "false"], ["true"]] -> ["true", "false", "true"]
            #     responses = [item for sublist in responses for item in sublist]
            results = [self.ResolveResponse(response) for response in responses]
            # Write results back onto the rows that were actually evaluated.
            valid_indices = valid_rows.index
            if not self.support_subquestions:
                for i, idx in enumerate(valid_indices):
                    dataframe.at[idx, 'answer_match_result'] = results[i]
            else:
                for i, idx in enumerate(valid_indices):
                    correct_answer_num, total_subquestions = map(int, results[i].split('/'))
                    dataframe.at[idx, 'correct_answer_num'] = correct_answer_num
                    dataframe.at[idx, 'total_subquestions'] = total_subquestions
                    # True only when every subquestion is answered correctly.
                    dataframe.at[idx, 'answer_match_result'] = (correct_answer_num == total_subquestions) and (total_subquestions > 0)
                    dataframe.at[idx, 'response_evaluation'] = responses[i]  # keep the raw LLM response
            output_file = storage.write(dataframe)
            # Build the statistics and write them straight to the JSON file.
            stats = self.statistic(storage.file_name_prefix, dataframe, self.compare_method)
            # Reset the empty-response counter so repeated runs start clean.
            self.empty_responses_count = 0
            return [input_test_answer_key, input_gt_answer_key, input_question_key, 'answer_match_result']
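

if __name__ == "__main__":
    # Minimal smoke check, not part of the operator itself: construct the
    # operator in "match" mode (which needs no LLM serving) and exercise the
    # math comparison. Running the full pipeline additionally requires a
    # concrete DataFlowStorage implementation supplied by the surrounding
    # DataFlow project, which this sketch deliberately leaves out.
    evaluator = BenchDatasetEvaluatorQuestion(compare_method="match")
    print(evaluator.math_verify_compare("1/2", "0.5"))  # expected: True
    print(BenchDatasetEvaluatorQuestion.get_desc(lang="en"))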