DataFlow-VQA / operators /answer_extractor.py
aaron1141's picture
initial hf spaces demo
e783436
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.core import OperatorABC
from dataflow.utils.storage import DataFlowStorage
import pandas as pd
from typing import Union
@OPERATOR_REGISTRY.register()
class AnswerExtractionOperator(OperatorABC):
    """Fill an answer column by extracting an answer from each row's solution.

    When an LLM serving object is supplied, one prompt per selected row is
    built from the system prompt, the question and the solution, and the
    generated text is stored as the answer. Without an LLM serving object the
    solution text itself is copied into the answer column.
    """

    def __init__(self, llm_serving: Union[None, object] = None, overwrite: bool = False):
        """Store the configuration.

        Args:
            llm_serving: Object exposing ``generate_from_input(prompts)``;
                when falsy, solutions are copied verbatim as answers.
            overwrite: When False, rows that already hold a non-blank value
                in the output column are left untouched.
        """
        self.logger = get_logger()
        self.llm_serving = llm_serving
        self.overwrite = overwrite
        self.system_prompt = "You are a professional question answering system. You will be given a question with corresponding solution. Extract a concise and accurate answer from the provided solution. Output only the answer without any additional text."

    @staticmethod
    def get_desc(lang: str = "zh"):
        """Return a human-readable description of the operator.

        Args:
            lang: ``"zh"`` for Chinese, ``"en"`` for English; any other value
                yields a one-line English fallback.
        """
        if lang == "zh":
            return (
                "该算子用于从解答中提取答案,读取解答字段并调用LLM提取答案。\n"
                "输入参数:\n"
                "- input_question_key:问题字段名,默认为'question'\n"
                "- input_solution_key:解答字段名,默认为'solution'\n"
                "- output_key:答案字段名,默认为'answer'\n"
                "- overwrite:是否覆盖已有答案,默认为False\n"
                "输出参数:\n"
                "- output_key:提取的答案"
            )
        if lang == "en":
            return (
                "This operator extracts answers from solutions, reading from the solution field and using LLM to extract answers.\n"
                "Input Parameters:\n"
                "- input_question_key: Question field name, default 'question'\n"
                "- input_solution_key: Solution field name, default 'solution'\n"
                "- output_key: Answer field name, default 'answer'\n"
                "- overwrite: Whether to overwrite existing answers, default False\n"
                "Output Parameters:\n"
                "- output_key: Extracted answer"
            )
        return "AnswerExtractionOperator extracts answers from solutions using LLM."

    @staticmethod
    def _is_blank(value) -> bool:
        """Return True when a cell holds no usable text.

        Blank means: None, any scalar missing marker pandas recognises
        (NaN, pd.NA, NaT, ...), or a string that is empty after stripping.
        """
        if isinstance(value, str):
            return value.strip() == ""
        # Guard with is_scalar: pd.isna on a list/array returns an array,
        # whose truth value is ambiguous. Non-scalars count as non-blank.
        return bool(pd.api.types.is_scalar(value) and pd.isna(value))

    def run(self, storage: DataFlowStorage, input_question_key: str = "question", input_solution_key: str = "solution", output_key: str = "answer"):
        """Extract answers for eligible rows and write the dataframe back.

        A row is eligible when its solution is non-blank and, unless
        ``overwrite`` is set, its current output value (if the column exists)
        is blank.

        Args:
            storage: Storage the dataframe is read from and written to.
            input_question_key: Column holding the question text; only
                required when an LLM serving object is configured.
            input_solution_key: Column holding the solution text.
            output_key: Column receiving the extracted answers.

        Returns:
            A single-element list containing ``output_key``.

        Raises:
            ValueError: If a required input column is missing, or if the LLM
                returns a different number of answers than prompts.
        """
        dataframe = storage.read("dataframe")
        if input_solution_key not in dataframe.columns:
            raise ValueError(f"input_solution_key: {input_solution_key} not found in dataframe columns.")

        use_llm = bool(self.llm_serving)
        # Questions are only needed to build prompts, so the column is only
        # mandatory on the LLM path; fail early with a clear message instead
        # of a bare KeyError further down.
        if use_llm and input_question_key not in dataframe.columns:
            raise ValueError(f"input_question_key: {input_question_key} not found in dataframe columns.")

        # The solution must be present (not missing, not whitespace-only).
        mask = ~dataframe[input_solution_key].apply(self._is_blank).astype(bool)

        # Without overwrite, only rows whose current output is blank are processed.
        if not self.overwrite and output_key in dataframe.columns:
            mask &= dataframe[output_key].apply(self._is_blank).astype(bool)

        if mask.any():
            solutions = dataframe.loc[mask, input_solution_key].tolist()
            if use_llm:
                questions = dataframe.loc[mask, input_question_key].tolist()
                prompts = [
                    self.system_prompt + f"\n\nQuestion: {q}\nSolution: {s}\nNow extract the answer."
                    for q, s in zip(questions, solutions)
                ]
                answers = list(self.llm_serving.generate_from_input(prompts))
                if len(answers) != len(prompts):
                    raise ValueError(
                        f"LLM returned {len(answers)} answers for {len(prompts)} prompts."
                    )
            else:
                # No LLM configured: the solution itself is used as the answer.
                answers = solutions

            # Write the answers back to the selected rows only.
            dataframe.loc[mask, output_key] = answers
        elif output_key not in dataframe.columns:
            # Nothing to process: skip the LLM call, but still create the
            # output column so the returned key exists in the written data.
            dataframe[output_key] = None

        output_file = storage.write(dataframe)
        self.logger.info(f"Extracted answers saved to {output_file}")
        return [output_key]