Spaces:
Sleeping
Sleeping
File size: 2,553 Bytes
e783436 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 | import pandas as pd
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.utils.storage import DataFlowStorage
from dataflow.core import OperatorABC
from dataflow.core import LLMServingABC
@OPERATOR_REGISTRY.register()
class AddMissingBlankOperator(OperatorABC):
    """Use an LLM to rewrite ``Fill-in`` rows, e.g. to restore a missing blank.

    Only rows whose ``type`` column equals ``"Fill-in"`` (exact,
    case-sensitive match) are sent to the LLM. For each such row the LLM
    output is written into ``output_key``. The exception is when the output
    is the ``"ORIGINAL"`` sentinel (ignoring surrounding whitespace) or
    ``None``; in those cases the row is left untouched. All other rows are
    never modified.
    """

    # Value of the ``type`` column that selects rows for processing.
    _TARGET_TYPE = "Fill-in"
    # Sentinel the LLM returns to signal "keep the question as-is".
    _KEEP_SENTINEL = "ORIGINAL"

    def __init__(
        self,
        llm_serving: LLMServingABC,
        prompt_template,
    ):
        """Create the operator.

        Args:
            llm_serving: Backend whose ``generate_from_input`` is called with
                the list of built prompts.
            prompt_template: Object providing ``build_prompt``; must not be None.

        Raises:
            ValueError: If ``prompt_template`` is None.
        """
        # Validate before storing anything so a bad instance is never half-built.
        if prompt_template is None:
            raise ValueError("prompt_template cannot be None")
        self.logger = get_logger()
        self.llm_serving = llm_serving
        self.prompt_template = prompt_template

    def run(
        self,
        storage: DataFlowStorage,
        output_key: str = "question",
        **input_keys
    ):
        """Process ``Fill-in`` rows and write the dataframe back to storage.

        Args:
            storage: Storage read via ``storage.read('dataframe')`` and
                written back via ``storage.write``.
            output_key: Column that receives the LLM output. It is created if
                missing; in that case rows that are not overwritten hold NaN.
            **input_keys: Mapping of prompt field name -> dataframe column
                name. Each field's row value is passed to
                ``prompt_template.build_prompt``.

        Returns:
            ``output_key``.
        """
        self.storage: DataFlowStorage = storage
        self.output_key = output_key
        self.logger.info("Running AddMissingBlankOperator...")
        self.input_keys = input_keys

        dataframe = storage.read('dataframe')
        self.logger.info(f"Loading, number of rows: {len(dataframe)}")

        indices = self._select_target_rows(dataframe)
        if indices:
            self._generate_and_write_back(dataframe, indices, output_key, input_keys)

        # The dataframe is always written back, even when nothing changed.
        self.storage.write(dataframe)
        return output_key

    def _select_target_rows(self, dataframe):
        """Return the index labels of rows whose ``type`` equals ``Fill-in``."""
        if 'type' not in dataframe.columns:
            self.logger.warning("No 'type' column found, skipping LLM generation.")
            return []
        mask = dataframe['type'] == self._TARGET_TYPE
        indices = dataframe.index[mask].tolist()
        if not indices:
            self.logger.info("No rows with type=='Fill-in' to process.")
        return indices

    def _generate_and_write_back(self, dataframe, indices, output_key, input_keys):
        """Build one prompt per selected row, query the LLM, write results in place."""
        need_fields = set(input_keys.keys())
        llm_inputs = []
        for idx in indices:
            row = dataframe.loc[idx]
            key_dict = {key: row[input_keys[key]] for key in need_fields}
            llm_inputs.append(self.prompt_template.build_prompt(need_fields, **key_dict))
        self.logger.info(f"Prepared {len(llm_inputs)} prompts for LLM generation.")

        # Materialize so the length can be checked even if the backend yields lazily.
        generated_outputs = list(self.llm_serving.generate_from_input(llm_inputs))
        if len(generated_outputs) != len(indices):
            # zip() below pairs outputs with rows positionally and stops at the
            # shorter sequence; surface the mismatch instead of dropping rows silently.
            self.logger.warning(
                f"LLM returned {len(generated_outputs)} outputs for "
                f"{len(indices)} prompts; unmatched rows are left unchanged."
            )

        for idx, gen_output in zip(indices, generated_outputs):
            if self._keeps_original(gen_output):
                continue
            dataframe.at[idx, output_key] = gen_output

    @classmethod
    def _keeps_original(cls, gen_output):
        """Return True when ``gen_output`` must not overwrite the row.

        That is the case for the ``"ORIGINAL"`` sentinel, tolerating surrounding
        whitespace or newlines, and for ``None``. ``None`` presumably signals a
        failed generation (TODO confirm against the serving backend); writing
        it would erase the original question.
        """
        if gen_output is None:
            return True
        return isinstance(gen_output, str) and gen_output.strip() == cls._KEEP_SENTINEL