from dataflow.utils.registry import OPERATOR_REGISTRY from dataflow import get_logger from dataflow.utils.storage import DataFlowStorage from dataflow.core import OperatorABC from dataflow.core import LLMServingABC from dataflow.prompts.reasoning.math import MathAnswerGeneratorPrompt from dataflow.prompts.reasoning.general import GeneralAnswerGeneratorPrompt from dataflow.prompts.reasoning.diy import DiyAnswerGeneratorPrompt from dataflow.core.prompt import prompt_restrict, DIYPromptABC import pandas as pd from typing import Union, List, Tuple import re import os @prompt_restrict( MathAnswerGeneratorPrompt, GeneralAnswerGeneratorPrompt, DiyAnswerGeneratorPrompt ) @OPERATOR_REGISTRY.register() class VQAReasoningAnswerGenerator(OperatorABC): ''' Answer Generator is a class that generates answers for given questions. ''' def __init__(self, llm_serving: LLMServingABC, prompt_template: Union[MathAnswerGeneratorPrompt, GeneralAnswerGeneratorPrompt, DiyAnswerGeneratorPrompt, DIYPromptABC] = MathAnswerGeneratorPrompt, skip_text_only: bool=False, input_image_default_basedir = "./" ): self.logger = get_logger() if prompt_template is None: prompt_template = MathAnswerGeneratorPrompt() self.prompts = prompt_template self.llm_serving = llm_serving self.skip_text_only = skip_text_only self.input_image_default_basedir = input_image_default_basedir @staticmethod def get_desc(lang: str = "zh"): if lang == "zh": return ( "该算子用于为给定问题生成答案,调用大语言模型进行推理。\n" "输入参数:\n" "- llm_serving:LLM服务实例,用于生成答案\n" "- prompt_template:提示模板对象,用于构建生成提示词\n" "输出参数:\n" "- output_key:生成的答案字段,默认'generated_cot'" ) elif lang == "en": return ( "This operator generates answers for given questions using LLMs for reasoning. \n" "Input Parameters:\n" "- llm_serving: LLM serving instance for answer generation\n" "- prompt_template: Prompt template object for constructing generation prompts\n" "Output Parameters:\n" "- output_key: Generated answer field, default 'generated_cot'" ) else: return "AnswerGenerator produces answers for questions using large language models." def _validate_dataframe(self, dataframe: pd.DataFrame): required_keys = [self.input_key, self.input_image_basedir_key, self.input_caption_key, self.input_skip_key] missing = [k for k in required_keys if k!=None and k not in dataframe.columns] if missing: raise ValueError(f"Missing required column(s): {missing}") def _prepare_vlm_inputs(self, dataframe) -> Tuple[List[str], List[List[str]], List[List[str]], List[int], List[int]]: """ Parses prompts for image markdown, extracts paths and text segments, and structures them into interleaved lists for the VLM server. 返回: user_prompts: List[str] (所有的问题) list_of_image_paths: List[List[str]] (所有请求的绝对路径列表) list_of_text_segments: List[List[str]] (所有图像标签) vqa_ids: List[int] (含有图片的问题编号) unskipped_ids: List[int] (未跳过的问题编号,这里跳过是指保留已经回答的答案,不重新回答) """ list_of_image_paths: List[List[str]] = [] list_of_text_segments: List[List[str]] = [] user_prompts:List[str] = [] vqa_ids = [] # Markdown 图片正则匹配: ![label](path) markdown_pattern = re.compile(r"!\[(.*?)\]\((.*?)\)") questions = dataframe[self.input_key].tolist() unskipped_ids = [] for index, question in enumerate(questions): # 1. 确定 Base Directory (图像的根目录) base_dir = self.input_image_default_basedir if self.input_image_basedir_key in dataframe.columns: row_base_dir = dataframe.loc[index, self.input_image_basedir_key] if row_base_dir: base_dir = row_base_dir # 2. 准备该请求的结构 current_paths: List[str] = [] current_segments: List[str] = [] current_user_prompt: str = "" last_end = 0 # 查找所有图片匹配项 matches = list(markdown_pattern.finditer(question)) # 3. 处理纯文本或交错文本/图像 if (not matches): if (not self.skip_text_only): # 纯文本提示:直接构建提示并作为唯一的文本片段 if self.input_skip_key != None and self.input_skip_key in dataframe.columns: if dataframe.loc[index, self.input_skip_key]: continue final_prompt_text = self.prompts.build_prompt(question) # 如果caption key存在,添加caption信息 if self.input_caption_key != None and self.input_caption_key in dataframe.columns: captions = dataframe.loc[index, self.input_caption_key] if captions and isinstance(captions, list): for cap_i, caption in enumerate(captions): final_prompt_text += f"\n Description of image {cap_i+1}: {caption}" user_prompts.append(final_prompt_text) list_of_image_paths.append([]) list_of_text_segments.append([]) unskipped_ids.append(index) continue vqa_complete = True # 4. 遍历匹配项,提取交错的文本片段和图像路径 for match in matches: leading_text = question[last_end:match.start()].strip() if leading_text: current_user_prompt += leading_text label = match.group(1).strip() path = match.group(2).strip() current_segments.append(label) # 4c. 记录绝对路径 (原始逻辑) full_path = os.path.join(base_dir, path) # 检查路径是否存在 if not os.path.isfile(full_path): self.logger.warning(f"Image file not found: {full_path} (from question index {index})") vqa_complete = False break current_paths.append(full_path) last_end = match.end() trailing_text = question[last_end:].strip() if trailing_text: current_user_prompt += trailing_text # 如果caption key存在,添加caption信息 if self.input_caption_key != None and self.input_caption_key in dataframe.columns: captions = dataframe.loc[index, self.input_caption_key] if captions and isinstance(captions, list): for cap_i, caption in enumerate(captions): current_user_prompt += f"\n Description of image {cap_i+1}: {caption}" # 5. 存储该请求的结果 if vqa_complete: vqa_ids.append(index) if self.input_skip_key != None and self.input_skip_key in dataframe.columns: if dataframe.loc[index, self.input_skip_key]: continue list_of_image_paths.append(current_paths) list_of_text_segments.append(current_segments) user_prompts.append(self.prompts.build_prompt(current_user_prompt)) unskipped_ids.append(index) return user_prompts, list_of_image_paths, list_of_text_segments, vqa_ids, unskipped_ids def run( self, storage, input_key:str = "instruction", output_key:str = "generated_cot", input_caption_key: str | None = None, input_skip_key: str | None = None, input_image_basedir_key = "image_basedir", ): ''' Runs the answer generation process, reading from the input file and saving results to output. ''' self.input_key, self.output_key = input_key, output_key self.input_caption_key = input_caption_key self.input_skip_key = input_skip_key self.input_image_basedir_key = input_image_basedir_key dataframe = storage.read("dataframe") self._validate_dataframe(dataframe) # 1. 准备 VLM 输入: 解析 Markdown 并获取路径和文本片段 user_prompts, list_of_image_paths, list_of_image_labels, vqa_ids, unskipped_ids = self._prepare_vlm_inputs(dataframe) # 2. 获取 System Prompt (假设它存储在 self.prompts 对象中) # 如果 self.prompts 没有 system_prompt 属性,则使用默认值。 system_prompt = "You are an intelligent chatbot good at college subjects." answers = self.llm_serving.generate_from_input_multi_images( list_of_image_paths=list_of_image_paths, list_of_image_labels=list_of_image_labels, system_prompt=system_prompt, user_prompts=user_prompts ) if self.skip_text_only: # 只写入vqa_ids指明的行 dataframe = dataframe.loc[vqa_ids].copy() # 注意,这里不重置索引 for idx, ans in zip(unskipped_ids, answers): dataframe.at[idx, self.output_key] = ans output_file = storage.write(dataframe) self.logger.info(f"Results saved to {output_file}") return [output_key]