File size: 10,346 Bytes
e783436
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.utils.storage import DataFlowStorage
from dataflow.core import OperatorABC
from dataflow.core import LLMServingABC

from dataflow.prompts.reasoning.math import MathAnswerGeneratorPrompt
from dataflow.prompts.reasoning.general import GeneralAnswerGeneratorPrompt
from dataflow.prompts.reasoning.diy import DiyAnswerGeneratorPrompt
from dataflow.core.prompt import prompt_restrict, DIYPromptABC

import pandas as pd
from typing import Union, List, Tuple
import re

import os

@prompt_restrict(
    MathAnswerGeneratorPrompt,
    GeneralAnswerGeneratorPrompt,
    DiyAnswerGeneratorPrompt
)
@OPERATOR_REGISTRY.register()
class VQAReasoningAnswerGenerator(OperatorABC):
    """
    Generate answers for questions that may embed Markdown image references.

    Every question is scanned for ``![label](path)`` references. The image
    markup is removed from the text, each path is joined with a base
    directory, and the resulting prompts, image paths and image labels are
    passed to ``llm_serving.generate_from_input_multi_images``.
    """

    def __init__(self,
                llm_serving: LLMServingABC,
                prompt_template: Union[MathAnswerGeneratorPrompt, GeneralAnswerGeneratorPrompt, DiyAnswerGeneratorPrompt, DIYPromptABC] = MathAnswerGeneratorPrompt,
                skip_text_only: bool=False,
                input_image_default_basedir = "./"
                ):
        """
        Args:
            llm_serving: Serving backend; ``run`` calls its
                ``generate_from_input_multi_images`` method.
            prompt_template: Prompt object whose ``build_prompt(question)``
                builds the user prompt. ``None`` selects
                ``MathAnswerGeneratorPrompt()``; a prompt *class* is
                instantiated without arguments.
            skip_text_only: When True, questions without any image reference
                are not sent to the model and are dropped from the dataframe
                written by ``run``.
            input_image_default_basedir: Base directory used to resolve image
                paths when a row does not provide a usable one.
        """
        self.logger = get_logger()

        if prompt_template is None:
            prompt_template = MathAnswerGeneratorPrompt()
        elif isinstance(prompt_template, type):
            # The signature default is the class itself, not an instance, so
            # the ``None`` check above never fires for it. Calling
            # ``build_prompt`` on a class would bind the question to ``self``.
            prompt_template = prompt_template()
        self.prompts = prompt_template
        self.llm_serving = llm_serving
        self.skip_text_only = skip_text_only
        self.input_image_default_basedir = input_image_default_basedir

    @staticmethod
    def get_desc(lang: str = "zh"):
        """Return a human-readable description of the operator.

        Args:
            lang: ``"zh"`` for Chinese, ``"en"`` for English; any other value
                yields a short generic English sentence.
        """
        if lang == "zh":
            return (
                "该算子用于为给定问题生成答案,调用大语言模型进行推理。\n"
                "输入参数:\n"
                "- llm_serving:LLM服务实例,用于生成答案\n"
                "- prompt_template:提示模板对象,用于构建生成提示词\n"
                "输出参数:\n"
                "- output_key:生成的答案字段,默认'generated_cot'"
            )
        elif lang == "en":
            return (
                "This operator generates answers for given questions using LLMs for reasoning. \n"
                "Input Parameters:\n"
                "- llm_serving: LLM serving instance for answer generation\n"
                "- prompt_template: Prompt template object for constructing generation prompts\n"
                "Output Parameters:\n"
                "- output_key: Generated answer field, default 'generated_cot'"
            )
        else:
            return "AnswerGenerator produces answers for questions using large language models."

    def _validate_dataframe(self, dataframe: pd.DataFrame):
        """Check that every configured (non-``None``) column key exists.

        Raises:
            ValueError: If one or more configured columns are missing.
        """
        required_keys = [self.input_key, self.input_image_basedir_key, self.input_caption_key, self.input_skip_key]
        missing = [k for k in required_keys if k is not None and k not in dataframe.columns]
        if missing:
            raise ValueError(f"Missing required column(s): {missing}")

    @staticmethod
    def _column_present(dataframe, key) -> bool:
        """Return True when ``key`` is configured and is a dataframe column."""
        return key is not None and key in dataframe.columns

    def _row_is_skipped(self, dataframe, row_label) -> bool:
        """Return True when the row's skip flag is set (truthy)."""
        if not self._column_present(dataframe, self.input_skip_key):
            return False
        return bool(dataframe.loc[row_label, self.input_skip_key])

    def _caption_suffix(self, dataframe, row_label) -> str:
        """Build the image-description text appended to a row's prompt.

        Returns an empty string when no caption column is configured or the
        cell is not a non-empty list.
        """
        if not self._column_present(dataframe, self.input_caption_key):
            return ""
        captions = dataframe.loc[row_label, self.input_caption_key]
        # Check the type first: evaluating the truthiness of an array-like
        # cell raises, while a plain list is always safe to test.
        if not (isinstance(captions, list) and captions):
            return ""
        return "".join(
            f"\n Description of image {cap_i+1}: {caption}"
            for cap_i, caption in enumerate(captions)
        )

    def _row_base_dir(self, dataframe, row_label):
        """Return the image base directory for a row.

        Falls back to ``input_image_default_basedir`` when the column is
        absent or the cell is not a non-empty path (e.g. ``None``/``NaN``).
        """
        if self._column_present(dataframe, self.input_image_basedir_key):
            row_base_dir = dataframe.loc[row_label, self.input_image_basedir_key]
            # NaN is truthy, so a bare truthiness test would let a missing
            # value through to os.path.join and crash there.
            if isinstance(row_base_dir, (str, os.PathLike)) and row_base_dir:
                return row_base_dir
        return self.input_image_default_basedir

    def _prepare_vlm_inputs(self, dataframe) -> Tuple[List[str], List[List[str]], List[List[str]], List, List]:
        """
        Parse the questions for Markdown images and build the model inputs.

        Rows are addressed by their dataframe index *labels*, so the returned
        ids can be used directly with ``.loc`` / ``.at`` in ``run`` whatever
        the index of the dataframe is.

        Returns:
            user_prompts: One built prompt per request that will be sent.
            list_of_image_paths: Per request, the image paths joined with the
                row's base directory (empty list for text-only requests).
            list_of_text_segments: Per request, the image labels taken from
                the Markdown references (empty list for text-only requests).
            vqa_ids: Index labels of rows that contain images and whose image
                files all exist.
            unskipped_ids: Index labels of the rows actually sent to the
                model, aligned with ``user_prompts``. A row is "skipped" when
                its skip flag is set, i.e. its existing answer is kept and it
                is not answered again.
        """
        list_of_image_paths: List[List[str]] = []
        list_of_text_segments: List[List[str]] = []
        user_prompts: List[str] = []
        vqa_ids = []
        unskipped_ids = []

        # Markdown image syntax: ![label](path)
        markdown_pattern = re.compile(r"!\[(.*?)\]\((.*?)\)")

        questions = dataframe[self.input_key].tolist()

        for row_label, question in zip(dataframe.index, questions):
            matches = list(markdown_pattern.finditer(question))

            # Text-only question: no image reference found.
            if not matches:
                if self.skip_text_only or self._row_is_skipped(dataframe, row_label):
                    continue
                # Captions are appended after the template has been applied.
                final_prompt_text = self.prompts.build_prompt(question)
                final_prompt_text += self._caption_suffix(dataframe, row_label)
                user_prompts.append(final_prompt_text)
                list_of_image_paths.append([])
                list_of_text_segments.append([])
                unskipped_ids.append(row_label)
                continue

            base_dir = self._row_base_dir(dataframe, row_label)

            current_paths: List[str] = []
            current_segments: List[str] = []
            text_parts: List[str] = []
            last_end = 0
            vqa_complete = True

            # Walk the matches, collecting the text between images and the
            # resolved image paths.
            for match in matches:
                leading_text = question[last_end:match.start()].strip()
                if leading_text:
                    text_parts.append(leading_text)

                current_segments.append(match.group(1).strip())
                full_path = os.path.join(base_dir, match.group(2).strip())

                if not os.path.isfile(full_path):
                    self.logger.warning(f"Image file not found: {full_path} (from question index {row_label})")
                    vqa_complete = False
                    break

                current_paths.append(full_path)
                last_end = match.end()

            # A row with any missing image is dropped entirely.
            if not vqa_complete:
                continue

            # The row counts as a valid VQA row even if it is skipped below,
            # so that run() keeps it when skip_text_only is enabled.
            vqa_ids.append(row_label)
            if self._row_is_skipped(dataframe, row_label):
                continue

            trailing_text = question[last_end:].strip()
            if trailing_text:
                text_parts.append(trailing_text)

            # NOTE(review): the stripped text pieces are concatenated without
            # a separator (existing behavior), so words on either side of an
            # image are fused — confirm this is intended.
            current_user_prompt = "".join(text_parts)
            # For image rows the captions go into the question text before
            # the template is applied.
            current_user_prompt += self._caption_suffix(dataframe, row_label)

            list_of_image_paths.append(current_paths)
            list_of_text_segments.append(current_segments)
            user_prompts.append(self.prompts.build_prompt(current_user_prompt))
            unskipped_ids.append(row_label)

        return user_prompts, list_of_image_paths, list_of_text_segments, vqa_ids, unskipped_ids

    def run(
        self, 
        storage,
        input_key:str = "instruction", 
        output_key:str = "generated_cot",
        input_caption_key: str | None = None,
        input_skip_key: str | None = None,
        input_image_basedir_key = "image_basedir",
        ):
        """
        Read the dataframe from ``storage``, generate answers and write it back.

        Args:
            storage: Object providing ``read("dataframe")`` and
                ``write(dataframe)``.
            input_key: Column holding the questions.
            output_key: Column the generated answers are written to.
            input_caption_key: Optional column holding a list of image
                captions per row; they are appended to the prompt.
            input_skip_key: Optional column; rows with a truthy value are not
                sent to the model and keep their current ``output_key`` value.
            input_image_basedir_key: Column holding the per-row image base
                directory. Pass ``None`` to always use the default directory.

        Returns:
            A single-element list containing ``output_key``.

        Raises:
            ValueError: If a configured column is missing from the dataframe.
        """
        self.input_key, self.output_key = input_key, output_key
        self.input_caption_key = input_caption_key
        self.input_skip_key = input_skip_key
        self.input_image_basedir_key = input_image_basedir_key
        dataframe = storage.read("dataframe")
        self._validate_dataframe(dataframe)

        # 1. Parse the Markdown and collect prompts, image paths and labels.
        user_prompts, list_of_image_paths, list_of_image_labels, vqa_ids, unskipped_ids = self._prepare_vlm_inputs(dataframe)

        # 2. The system prompt is fixed; it is not taken from self.prompts.
        system_prompt = "You are an intelligent chatbot good at college subjects."

        if user_prompts:
            answers = list(self.llm_serving.generate_from_input_multi_images(
                list_of_image_paths=list_of_image_paths,
                list_of_image_labels=list_of_image_labels, 
                system_prompt=system_prompt,
                user_prompts=user_prompts
            ))
        else:
            # Nothing to ask: avoid sending an empty batch to the backend.
            self.logger.warning("No questions to answer; skipping generation.")
            answers = []

        if len(answers) != len(unskipped_ids):
            # zip() below silently truncates, so make the mismatch visible.
            self.logger.warning(
                f"Expected {len(unskipped_ids)} answers but received {len(answers)}; "
                "unmatched rows are left without a generated answer."
            )

        if self.skip_text_only:
            # Keep only the rows listed in vqa_ids; the index is deliberately
            # not reset so the labels in unskipped_ids stay valid.
            dataframe = dataframe.loc[vqa_ids].copy()

        for idx, ans in zip(unskipped_ids, answers):
            dataframe.at[idx, self.output_key] = ans

        output_file = storage.write(dataframe)
        self.logger.info(f"Results saved to {output_file}")

        return [output_key]