File size: 2,553 Bytes
e783436
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import pandas as pd
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger

from dataflow.utils.storage import DataFlowStorage
from dataflow.core import OperatorABC
from dataflow.core import LLMServingABC

@OPERATOR_REGISTRY.register()
class AddMissingBlankOperator(OperatorABC):
    """Rewrite the output column of "Fill-in" rows with LLM-generated text.

    For every dataframe row whose ``type`` column equals ``"Fill-in"``, a
    prompt is built from the configured prompt template and sent to the LLM
    serving backend. The generated text is written to the output column of
    that row unless the model answers with the ``"ORIGINAL"`` sentinel.
    """

    # Row type (value of the 'type' column) that this operator processes.
    _TARGET_TYPE = "Fill-in"
    # Generated answer for which the row is left untouched.
    # NOTE(review): presumably the prompt template instructs the model to
    # reply with this when no rewrite is needed -- confirm against the template.
    _KEEP_SENTINEL = "ORIGINAL"

    def __init__(
            self,
            llm_serving: LLMServingABC,
            prompt_template,
        ):
        """Store the LLM backend and the prompt template.

        Args:
            llm_serving: Backend whose ``generate_from_input`` is called with
                the list of prompts.
            prompt_template: Object whose ``build_prompt`` is called once per
                processed row.

        Raises:
            ValueError: If ``prompt_template`` is None.
        """
        # Validate first so a misconfigured operator fails before any setup.
        if prompt_template is None:
            raise ValueError("prompt_template cannot be None")
        self.logger = get_logger()
        self.llm_serving = llm_serving
        self.prompt_template = prompt_template

    def run(
            self,
            storage: DataFlowStorage,
            output_key: str = "question",
            **input_keys
        ):
        """Process the stored dataframe and write it back.

        Args:
            storage: Storage the dataframe is read from (``read('dataframe')``)
                and written back to (``write``).
            output_key: Column that receives the generated text.
            **input_keys: Mapping of prompt field name to the dataframe column
                whose value is passed to ``build_prompt`` under that name.

        Returns:
            ``output_key``, unchanged.
        """
        self.storage: DataFlowStorage = storage
        self.output_key = output_key
        self.input_keys = input_keys
        self.logger.info("Running AddMissingBlankOperator...")

        # Load the raw dataframe from storage.
        dataframe = storage.read('dataframe')
        self.logger.info(f"Loading, number of rows: {len(dataframe)}")

        # Only rows where type == "Fill-in" are processed.
        if 'type' not in dataframe.columns:
            self.logger.warning("No 'type' column found, skipping LLM generation.")
        else:
            mask = dataframe['type'] == self._TARGET_TYPE
            indices = dataframe.index[mask].tolist()
            if indices:
                self._rewrite_rows(dataframe, indices, output_key, input_keys)
            else:
                self.logger.info("No rows with type=='Fill-in' to process.")

        # The dataframe is written back even when nothing was generated.
        self.storage.write(dataframe)
        return output_key

    def _rewrite_rows(self, dataframe, indices, output_key, input_keys):
        """Generate text for the rows at ``indices`` and store it in place."""
        need_fields = set(input_keys)

        llm_inputs = []
        for idx in indices:
            row = dataframe.loc[idx]
            key_dict = {key: row[input_keys[key]] for key in need_fields}
            llm_inputs.append(
                self.prompt_template.build_prompt(need_fields, **key_dict)
            )
        self.logger.info(f"Prepared {len(llm_inputs)} prompts for LLM generation.")
        generated_outputs = self.llm_serving.generate_from_input(llm_inputs)

        # zip() below stops at the shorter sequence; make a mismatch visible
        # instead of silently leaving trailing rows unprocessed.
        if len(generated_outputs) != len(indices):
            self.logger.warning(
                f"Expected {len(indices)} generated outputs but got "
                f"{len(generated_outputs)}; unmatched rows are left unchanged."
            )

        for idx, gen_output in zip(indices, generated_outputs):
            # A missing generation must not wipe out the existing value.
            if gen_output is None:
                self.logger.warning(
                    f"No output generated for row {idx}; keeping existing value."
                )
                continue
            # Tolerate surrounding whitespace/newlines around the sentinel so
            # it is never written into the dataframe as content.
            if isinstance(gen_output, str) and gen_output.strip() == self._KEEP_SENTINEL:
                continue
            dataframe.at[idx, output_key] = gen_output