Spaces:
Running
Running
| from dataflow.utils.registry import OPERATOR_REGISTRY | |
| from dataflow import get_logger | |
| from dataflow.core import OperatorABC | |
| from dataflow.utils.storage import DataFlowStorage | |
| import pandas as pd | |
| from typing import Union | |
class AnswerExtractionOperator(OperatorABC):
    """Extract a concise final answer from each row's solution text.

    A row is selected when its solution column is non-blank and, unless
    ``overwrite`` is True, its output column is still blank. Selected rows
    are answered by ``llm_serving`` when one is configured. Otherwise the
    solution text is copied into the output column unchanged.

    NOTE(review): ``OPERATOR_REGISTRY`` is imported by this module but the
    class is not decorated with ``@OPERATOR_REGISTRY.register()``; confirm
    whether registration is intended.
    """

    def __init__(self, llm_serving: Union[None, object] = None, overwrite: bool = False):
        """Configure the operator.

        Args:
            llm_serving: Object providing ``generate_from_input(prompts)``.
                It presumably returns one answer per prompt (TODO confirm).
                When ``None``, solutions are copied through as answers.
            overwrite: When True, rows that already hold a non-blank answer
                are processed again. When False they are left untouched.
        """
        self.logger = get_logger()
        self.llm_serving = llm_serving
        self.overwrite = overwrite
        # Instruction text prepended to every per-row prompt sent to the LLM.
        self.system_prompt = "You are a professional question answering system. You will be given a question with corresponding solution. Extract a concise and accurate answer from the provided solution. Output only the answer without any additional text."

    @staticmethod
    def get_desc(lang: str = "zh"):
        """Return a human-readable description of the operator.

        Static, so it can be called on the class or on an instance.

        Args:
            lang: ``"zh"`` for Chinese or ``"en"`` for English. Any other
                value yields a one-line English summary.
        """
        if lang == "zh":
            return (
                "该算子用于从解答中提取答案,读取解答字段并调用LLM提取答案。\n"
                "输入参数:\n"
                "- input_question_key:问题字段名,默认为'question'\n"
                "- input_solution_key:解答字段名,默认为'solution'\n"
                "- output_key:答案字段名,默认为'answer'\n"
                "- overwrite:是否覆盖已有答案,默认为False\n"
                "输出参数:\n"
                "- output_key:提取的答案"
            )
        elif lang == "en":
            return (
                "This operator extracts answers from solutions, reading from the solution field and using LLM to extract answers.\n"
                "Input Parameters:\n"
                "- input_question_key: Question field name, default 'question'\n"
                "- input_solution_key: Solution field name, default 'solution'\n"
                "- output_key: Answer field name, default 'answer'\n"
                "- overwrite: Whether to overwrite existing answers, default False\n"
                "Output Parameters:\n"
                "- output_key: Extracted answer"
            )
        else:
            return "AnswerExtractionOperator extracts answers from solutions using LLM."

    @staticmethod
    def _is_blank(value) -> bool:
        """Return True if *value* is a missing scalar (None, NaN, pd.NA, NaT) or a whitespace-only string."""
        if isinstance(value, str):
            return value.strip() == ""
        # Guard with is_scalar: pd.isna on a list/array returns an array, which
        # cannot be used as a single truth value.
        return bool(pd.api.types.is_scalar(value) and pd.isna(value))

    def run(self, storage: DataFlowStorage, input_question_key: str = "question", input_solution_key: str = "solution", output_key: str = "answer"):
        """Extract answers for eligible rows and write the dataframe back.

        Args:
            storage: Storage that the ``"dataframe"`` is read from and the
                updated dataframe is written to.
            input_question_key: Column holding the question text.
            input_solution_key: Column holding the solution text.
            output_key: Column that receives the extracted answers.

        Returns:
            ``[output_key]``, the list of columns this operator produced.

        Raises:
            ValueError: If the solution or question column is missing, or if
                ``llm_serving`` returns a different number of answers than
                prompts.
        """
        dataframe = storage.read("dataframe")
        # Validate both input columns up front so a missing column yields a
        # clear error instead of a bare pandas KeyError further down.
        if input_solution_key not in dataframe.columns:
            raise ValueError(f"input_solution_key: {input_solution_key} not found in dataframe columns.")
        if input_question_key not in dataframe.columns:
            raise ValueError(f"input_question_key: {input_question_key} not found in dataframe columns.")

        # A row is eligible only when its solution is non-blank...
        mask = dataframe[input_solution_key].apply(lambda v: not self._is_blank(v)).astype(bool)
        # ...and, unless overwriting, its existing answer is still blank.
        if not self.overwrite and output_key in dataframe.columns:
            mask = mask & dataframe[output_key].apply(self._is_blank).astype(bool)

        solutions = dataframe.loc[mask, input_solution_key].tolist()
        questions = dataframe.loc[mask, input_question_key].tolist()

        if not solutions:
            # Nothing to extract: skip the LLM call entirely.
            answers = []
        elif self.llm_serving is not None:
            prompts = [
                self.system_prompt + f"\n\nQuestion: {q}\nSolution: {s}\nNow extract the answer."
                for q, s in zip(questions, solutions)
            ]
            # Materialize so the length can be checked before positional assignment.
            answers = list(self.llm_serving.generate_from_input(prompts))
            if len(answers) != len(prompts):
                raise ValueError(
                    f"llm_serving returned {len(answers)} answers for {len(prompts)} prompts."
                )
        else:
            # No LLM configured: use the solution text itself as the answer.
            answers = solutions

        # Write answers back only to the selected rows; other rows keep their values.
        dataframe.loc[mask, output_key] = answers
        output_file = storage.write(dataframe)
        self.logger.info(f"Extracted answers saved to {output_file}")
        return [output_key]