Spaces:
Sleeping
Sleeping
File size: 10,346 Bytes
e783436 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 | from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.utils.storage import DataFlowStorage
from dataflow.core import OperatorABC
from dataflow.core import LLMServingABC
from dataflow.prompts.reasoning.math import MathAnswerGeneratorPrompt
from dataflow.prompts.reasoning.general import GeneralAnswerGeneratorPrompt
from dataflow.prompts.reasoning.diy import DiyAnswerGeneratorPrompt
from dataflow.core.prompt import prompt_restrict, DIYPromptABC
import pandas as pd
from typing import Union, List, Tuple
import re
import os
@prompt_restrict(
    MathAnswerGeneratorPrompt,
    GeneralAnswerGeneratorPrompt,
    DiyAnswerGeneratorPrompt
)
@OPERATOR_REGISTRY.register()
class VQAReasoningAnswerGenerator(OperatorABC):
    """
    Generate answers for (possibly image-bearing) questions with a VLM.

    Each question is scanned for Markdown image references ``![label](path)``.
    The paths are resolved against a per-row base directory (or a default
    one), and the question text, the image paths and the image labels are
    sent to ``llm_serving.generate_from_input_multi_images``. The answers are
    written back to ``output_key``.
    """

    # Markdown image reference: group(1) = label / alt text, group(2) = path.
    _MARKDOWN_IMAGE = re.compile(r"!\[(.*?)\]\((.*?)\)")

    def __init__(self,
                 llm_serving: LLMServingABC,
                 prompt_template: Union[MathAnswerGeneratorPrompt, GeneralAnswerGeneratorPrompt, DiyAnswerGeneratorPrompt, DIYPromptABC] = MathAnswerGeneratorPrompt,
                 skip_text_only: bool = False,
                 input_image_default_basedir: str = "./"
                 ):
        """
        Args:
            llm_serving: Serving backend; must provide
                ``generate_from_input_multi_images``.
            prompt_template: Prompt object exposing ``build_prompt(text)``.
                If a prompt *class* is given (this includes the declared
                default), it is instantiated with no arguments. ``None`` falls
                back to ``MathAnswerGeneratorPrompt()``.
            skip_text_only: If True, questions without any image reference are
                not sent to the model, and the written dataframe keeps only
                rows whose images all resolved.
            input_image_default_basedir: Base directory for relative image
                paths. Used when a row has no usable base-directory value.
        """
        self.logger = get_logger()
        if prompt_template is None:
            prompt_template = MathAnswerGeneratorPrompt()
        elif isinstance(prompt_template, type):
            # The declared default is the class itself. Calling build_prompt on
            # the class would bind the question text to ``self``.
            prompt_template = prompt_template()
        self.prompts = prompt_template
        self.llm_serving = llm_serving
        self.skip_text_only = skip_text_only
        self.input_image_default_basedir = input_image_default_basedir

    @staticmethod
    def get_desc(lang: str = "zh"):
        """Return a human-readable operator description in ``lang`` ("zh" or "en")."""
        if lang == "zh":
            return (
                "该算子用于为给定问题生成答案,调用大语言模型进行推理。\n"
                "输入参数:\n"
                "- llm_serving:LLM服务实例,用于生成答案\n"
                "- prompt_template:提示模板对象,用于构建生成提示词\n"
                "输出参数:\n"
                "- output_key:生成的答案字段,默认'generated_cot'"
            )
        elif lang == "en":
            return (
                "This operator generates answers for given questions using LLMs for reasoning. \n"
                "Input Parameters:\n"
                "- llm_serving: LLM serving instance for answer generation\n"
                "- prompt_template: Prompt template object for constructing generation prompts\n"
                "Output Parameters:\n"
                "- output_key: Generated answer field, default 'generated_cot'"
            )
        else:
            return "AnswerGenerator produces answers for questions using large language models."

    @staticmethod
    def _has_value(value) -> bool:
        """Truthiness that treats pandas missing markers (None / NaN / NA) as empty.

        ``bool(float('nan'))`` is True, so a plain ``if value:`` would treat a
        missing cell as set.
        """
        if pd.api.types.is_scalar(value) and pd.isna(value):
            return False
        return bool(value)

    def _validate_dataframe(self, dataframe: pd.DataFrame):
        """Raise ValueError if a required input column is missing.

        The question column is always required. The caption and skip columns
        are required only when their key is set. The image-basedir column is
        optional: rows fall back to ``input_image_default_basedir``.
        """
        required_keys = [self.input_key, self.input_caption_key, self.input_skip_key]
        missing = [k for k in required_keys if k is not None and k not in dataframe.columns]
        if missing:
            raise ValueError(f"Missing required column(s): {missing}")

    def _row_base_dir(self, dataframe: pd.DataFrame, index):
        """Return the image base directory for the row labelled ``index``."""
        key = self.input_image_basedir_key
        if key is not None and key in dataframe.columns:
            row_base_dir = dataframe.loc[index, key]
            if self._has_value(row_base_dir):
                return row_base_dir
        return self.input_image_default_basedir

    def _should_skip(self, dataframe: pd.DataFrame, index) -> bool:
        """True if the row's skip cell holds a value, i.e. its existing answer is kept."""
        key = self.input_skip_key
        if key is None or key not in dataframe.columns:
            return False
        return self._has_value(dataframe.loc[index, key])

    def _caption_suffix(self, dataframe: pd.DataFrame, index) -> str:
        """Return the "Description of image k" lines for the row's caption list, or ""."""
        key = self.input_caption_key
        if key is None or key not in dataframe.columns:
            return ""
        captions = dataframe.loc[index, key]
        # Check the type first: truth-testing a non-list (e.g. an ndarray) can raise.
        if not isinstance(captions, list) or not captions:
            return ""
        return "".join(
            f"\n Description of image {cap_i+1}: {caption}"
            for cap_i, caption in enumerate(captions)
        )

    def _prepare_vlm_inputs(self, dataframe) -> Tuple[List[str], List[List[str]], List[List[str]], List, List]:
        """
        Parse each question for Markdown images and build the VLM requests.

        Returns:
            user_prompts: one built prompt per request sent to the model.
            list_of_image_paths: per request, the resolved image paths
                (empty for text-only requests).
            list_of_text_segments: per request, the image labels, aligned with
                the paths.
            vqa_ids: index labels of image-bearing rows whose images all exist.
                These include rows skipped via the skip column.
            unskipped_ids: index labels of the rows actually sent, aligned with
                ``user_prompts``. Skipped rows keep their existing answer and
                are not re-answered.
        """
        list_of_image_paths: List[List[str]] = []
        list_of_text_segments: List[List[str]] = []
        user_prompts: List[str] = []
        vqa_ids = []
        unskipped_ids = []

        # Iterate over index labels rather than positions, so that the .loc
        # lookups here and the .loc/.at writes in run() also work for
        # dataframes whose index is not 0..n-1.
        for index, question in dataframe[self.input_key].items():
            matches = list(self._MARKDOWN_IMAGE.finditer(question))

            if not matches:
                # Text-only question.
                if self.skip_text_only or self._should_skip(dataframe, index):
                    continue
                final_prompt_text = self.prompts.build_prompt(question)
                # NOTE: captions are appended after build_prompt here, but before
                # it in the image branch below; this placement is kept as it was.
                final_prompt_text += self._caption_suffix(dataframe, index)
                user_prompts.append(final_prompt_text)
                list_of_image_paths.append([])
                list_of_text_segments.append([])
                unskipped_ids.append(index)
                continue

            base_dir = self._row_base_dir(dataframe, index)
            text_parts: List[str] = []
            current_paths: List[str] = []
            current_segments: List[str] = []
            last_end = 0
            vqa_complete = True

            # Walk the matches, collecting the text between images and the
            # resolved image paths.
            for match in matches:
                leading_text = question[last_end:match.start()].strip()
                if leading_text:
                    text_parts.append(leading_text)
                label = match.group(1).strip()
                path = match.group(2).strip()
                full_path = os.path.join(base_dir, path)
                if not os.path.isfile(full_path):
                    self.logger.warning(f"Image file not found: {full_path} (from question index {index})")
                    vqa_complete = False
                    break
                current_segments.append(label)
                current_paths.append(full_path)
                last_end = match.end()

            if not vqa_complete:
                # Never send a request whose images are partly missing.
                continue

            vqa_ids.append(index)
            if self._should_skip(dataframe, index):
                continue

            trailing_text = question[last_end:].strip()
            if trailing_text:
                text_parts.append(trailing_text)
            # Join with newlines: plain concatenation fused the words on either
            # side of an image (e.g. "What is" + "shown" -> "What isshown").
            current_user_prompt = "\n".join(text_parts) + self._caption_suffix(dataframe, index)

            list_of_image_paths.append(current_paths)
            list_of_text_segments.append(current_segments)
            user_prompts.append(self.prompts.build_prompt(current_user_prompt))
            unskipped_ids.append(index)

        return user_prompts, list_of_image_paths, list_of_text_segments, vqa_ids, unskipped_ids

    def run(
        self,
        storage,
        input_key: str = "instruction",
        output_key: str = "generated_cot",
        input_caption_key: str | None = None,
        input_skip_key: str | None = None,
        input_image_basedir_key = "image_basedir",
    ):
        """
        Read the dataframe from ``storage``, generate answers and write it back.

        Args:
            storage: DataFlow storage; ``read("dataframe")`` and ``write(df)``
                are used.
            input_key: Column holding the questions (may contain Markdown images).
            output_key: Column the answers are written to.
            input_caption_key: Optional column holding a list of image captions.
            input_skip_key: Optional column. Rows where it holds a value are
                not re-answered.
            input_image_basedir_key: Optional column with a per-row image base
                directory.

        Returns:
            ``[output_key]``.
        """
        self.input_key, self.output_key = input_key, output_key
        self.input_caption_key = input_caption_key
        self.input_skip_key = input_skip_key
        self.input_image_basedir_key = input_image_basedir_key

        dataframe = storage.read("dataframe")
        self._validate_dataframe(dataframe)

        # 1. Build the VLM requests: parse the Markdown images, resolve paths and labels.
        user_prompts, list_of_image_paths, list_of_image_labels, vqa_ids, unskipped_ids = self._prepare_vlm_inputs(dataframe)

        # 2. Fixed system prompt.
        system_prompt = "You are an intelligent chatbot good at college subjects."
        if user_prompts:
            answers = list(self.llm_serving.generate_from_input_multi_images(
                list_of_image_paths=list_of_image_paths,
                list_of_image_labels=list_of_image_labels,
                system_prompt=system_prompt,
                user_prompts=user_prompts
            ))
        else:
            # Nothing to ask; avoid calling the backend with empty batches.
            answers = []

        if len(answers) != len(unskipped_ids):
            self.logger.warning(
                f"Received {len(answers)} answers for {len(unskipped_ids)} prompts; "
                "rows without a matching answer are left unchanged."
            )

        if self.skip_text_only:
            # Keep only the image-bearing rows whose images all resolved. The
            # original index is kept so the label-based writes below line up.
            dataframe = dataframe.loc[vqa_ids].copy()
        for idx, ans in zip(unskipped_ids, answers):
            dataframe.at[idx, self.output_key] = ans

        output_file = storage.write(dataframe)
        self.logger.info(f"Results saved to {output_file}")
        return [output_key]
|