DataFlow-VQA / pipelines /generate_cot.py
aaron1141's picture
initial hf spaces demo
e783436
import os
import sys
import pandas as pd
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from dataflow.operators.core_text import PandasOperator
from operators.bench_evaluate import BenchDatasetEvaluatorQuestion
from operators.vqa_answer_generator import VQAReasoningAnswerGenerator
from dataflow.serving import APILLMServing_request, APIVLMServing_openai, LocalVLMServing_vllm
from dataflow.utils.storage import FileStorage
from dataflow.operators.reasoning import (
ReasoningAnswerGenerator,
ReasoningAnswerGroundTruthFilter
)
from dataflow.prompts.reasoning.math import MathAnswerGeneratorPrompt
from dataflow.operators.core_text import GeneralFilter
from dataflow import get_logger
from dataflow.pipeline import PipelineABC
from typing import Iterable
import re
import argparse
import shutil
def make_remove_think_fn(input_key, output_key):
    """Build a DataFrame transform that strips ``<think>...</think>`` reasoning.

    The returned function copies the frame and writes a cleaned version of
    column *input_key* into column *output_key*:

    * missing values (per ``pd.isna``) are passed through unchanged;
    * non-string values are converted with ``str()`` first;
    * text without a closing ``</think>`` tag (any letter case) is only
      stripped of surrounding whitespace;
    * otherwise an opening ``<think>`` is prepended, so output whose opening
      tag is absent is still handled. Then every ``<think>...</think>`` span is
      removed, which includes everything up to and including the first
      ``</think>``. The remainder is whitespace-stripped.

    If *input_key* is not a column, the copy is returned without *output_key*.

    Args:
        input_key: name of the column holding raw model output.
        output_key: name of the column to write the cleaned text to.

    Returns:
        A callable ``fn(df) -> DataFrame`` that never mutates its input.
    """
    pattern = re.compile(r'<think>.*?</think>', flags=re.DOTALL | re.IGNORECASE)

    def clean_text(t):
        if pd.isna(t):
            return t
        # Non-string cells (e.g. numeric answers) would make the `in` test
        # below raise TypeError, so normalise them to text first.
        text = t if isinstance(t, str) else str(t)
        # Match the tag case-insensitively, consistent with the IGNORECASE regex.
        if "</think>" not in text.lower():
            return text.strip()
        # Prepend "<think>" so reasoning emitted without its opening tag is
        # still captured by the first match (a doubled "<think><think>" is
        # harmless because the match starts at the first tag).
        return pattern.sub("", "<think>" + text).strip()

    def fn(df):
        df = df.copy()
        if input_key in df.columns:
            df[output_key] = df[input_key].apply(clean_text)
        return df

    return fn
class RejectSamplingPipeline(PipelineABC):
    """Chain-of-thought generation for VQA data with LLM-judged reject sampling.

    ``forward()`` runs ``max_retries`` rounds. Each round asks the answer model
    for a reasoning answer to ``question`` (column ``generated_cot``). It then
    strips ``<think>...</think>`` reasoning into ``llm_short_answer``. Finally,
    a judge model compares that short answer with the ground-truth ``answer``
    column. All steps are cached as JSONL under ``./cot_cache`` with the
    ``reject_sampling`` file-name prefix.
    """

    def __init__(self, first_entry_file_name, answer_api_url, judge_api_url, answer_model, judge_model,
                 answer_api_key_env="DF_API_KEY", judge_api_key_env="DF_API_KEY",
                 max_retries=5, max_workers=100):
        """Wire up storage, serving clients and operators.

        Args:
            first_entry_file_name: input JSONL file that seeds the first step.
            answer_api_url: base URL of the OpenAI-compatible endpoint serving
                the answer (vision-language) model.
            judge_api_url: base URL of the OpenAI-compatible endpoint serving
                the judge model; ``/chat/completions`` is appended to it. A
                trailing slash is tolerated.
            answer_model: model name sent to the answer endpoint.
            judge_model: model name sent to the judge endpoint.
            answer_api_key_env: name of the environment variable holding the
                answer endpoint's API key.
            judge_api_key_env: name of the environment variable holding the
                judge endpoint's API key.
            max_retries: number of reject-sampling rounds run by ``forward``.
            max_workers: number of parallel request workers per serving client.
        """
        super().__init__()
        self.storage = FileStorage(
            first_entry_file_name=first_entry_file_name,
            cache_path="./cot_cache",
            file_name_prefix="reject_sampling",
            cache_type="jsonl",
        )
        self.max_retries = max_retries
        self.logger = get_logger()
        # Answer model. Non-zero temperature, presumably so that retry rounds
        # draw different samples.
        self.llm_answer_serving = APIVLMServing_openai(
            api_url=answer_api_url,
            model_name=answer_model,
            key_name_of_api_key=answer_api_key_env,
            max_workers=max_workers,
            timeout=600.0,
            max_tokens=8192,
            temperature=0.7,
        )
        # Without rstrip, a base URL such as ".../v1/" would produce
        # ".../v1//chat/completions", which many servers reject.
        judge_base_url = judge_api_url.rstrip("/")
        self.llm_serving = APILLMServing_request(
            api_url=f"{judge_base_url}/chat/completions",
            model_name=judge_model,
            key_name_of_api_key=judge_api_key_env,
            max_workers=max_workers,
            read_timeout=300.0
        )
        # Difficulty filter (keep items where accuracy <= 1.0).
        # NOTE(review): this operator is never invoked in forward(). Confirm
        # whether it is still needed.
        self.difficulty_filter = GeneralFilter(
            filter_rules=[lambda df: df['accuracy'] <= 1.0]
        )
        # LLM answer generation (math answer prompt; text-only rows are
        # processed too).
        self.answer_generator = VQAReasoningAnswerGenerator(
            llm_serving=self.llm_answer_serving,
            prompt_template=MathAnswerGeneratorPrompt(),
            skip_text_only=False,
        )
        # generated_cot -> llm_short_answer with <think>...</think> removed.
        self.think_cleaner = PandasOperator(process_fn=[make_remove_think_fn(input_key="generated_cot", output_key="llm_short_answer")])
        # Identity operator, run once at the start of forward() "for pipeline
        # compilation". It presumably declares `answer_match_result` as an
        # output before later rounds read it as a skip key. Confirm against
        # PipelineABC.
        self.noop = PandasOperator(process_fn=[lambda df: df])
        # LLM verification: semantic comparison of the short answer with the
        # ground truth, using the evaluator's default prompt.
        self.answer_groundtruth_filter = BenchDatasetEvaluatorQuestion(
            compare_method="semantic",
            llm_serving=self.llm_serving,
            prompt_template=None,  # using default prompt
            eval_result_path="./cot_cache/eval_results.jsonl",
            support_subquestions=True,
            skip_true=True
        )

    def forward(self):
        """Run one no-op step, then ``max_retries`` generate/clean/judge rounds.

        Each round consumes three storage steps. Round 0 answers every item.
        Later rounds pass ``answer_match_result`` as ``input_skip_key`` so
        that items already answered correctly are skipped.
        """
        self.noop.run(storage=self.storage.step(), output_key="answer_match_result")  # for pipeline compilation, do nothing
        for i in range(self.max_retries):
            input_skip_key = "answer_match_result" if i > 0 else None
            # Generate answers (skip items already answered correctly)
            self.answer_generator.run(
                storage=self.storage.step(),
                input_key="question",
                output_key="generated_cot",
                input_skip_key=input_skip_key,
                input_image_basedir_key="image_basedir",
            )
            self.think_cleaner.run(storage=self.storage.step(), output_key="llm_short_answer")
            self.answer_groundtruth_filter.run(
                storage=self.storage.step(),
                input_test_answer_key="llm_short_answer",
                input_gt_answer_key="answer",
                input_question_key="question",
            )
def find_latest_step_file(cache_dir="./cot_cache", prefix="reject_sampling", suffix="jsonl"):
    """Return the path of the highest-numbered ``{prefix}_step{N}.{suffix}`` file.

    Only file names in *cache_dir* that match the pattern exactly are
    considered, so leftovers such as ``reject_sampling_step3.jsonl.bak`` are
    ignored.

    Args:
        cache_dir: directory to scan.
        prefix: file-name prefix used by the pipeline's storage.
        suffix: file extension (without the dot).

    Returns:
        ``os.path.join(cache_dir, <file name with the largest N>)``.

    Raises:
        FileNotFoundError: if *cache_dir* does not exist or holds no matching file.
    """
    step_re = re.compile(rf"{re.escape(prefix)}_step(\d+)\.{re.escape(suffix)}")
    steps = {}
    for name in os.listdir(cache_dir):
        match = step_re.fullmatch(name)
        if match:
            steps[int(match.group(1))] = name
    if not steps:
        raise FileNotFoundError(f"No {prefix}_step*.{suffix} files found in {cache_dir!r}")
    return os.path.join(cache_dir, steps[max(steps)])


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="CoT Generation Pipeline with Reject Sampling")
    parser.add_argument("--input_file", type=str, required=True, help="Path to the input JSONL file (curated_vqa.jsonl)")
    parser.add_argument("--max_retries", type=int, default=5, help="Maximum number of reject sampling rounds")
    parser.add_argument("--answer_api_url", type=str, default="https://api.xxx.com/v1", help="Url where you serve your qwen model (e.g. via vllm)")
    parser.add_argument("--judge_api_url", type=str, default="https://api.openai.com/v1", help="Base URL of the OpenAI-compatible API for answer verification (e.g. https://api.openai.com/v1)")
    parser.add_argument("--answer_model", type=str, default="qwen3-vl-235b-thinking", help="Model to use for answer generation")
    parser.add_argument("--judge_model", type=str, default="gpt-5-mini", help="Model to use for answer verification")
    parser.add_argument("--answer_api_key_env", type=str, default="DF_API_KEY", help="Environment variable name holding the API key for the answer model")
    parser.add_argument("--judge_api_key_env", type=str, default="DF_API_KEY", help="Environment variable name holding the API key for the judge model")
    parser.add_argument("--max_workers", type=int, default=100, help="Number of parallel API workers")
    args = parser.parse_args()
    model = RejectSamplingPipeline(
        args.input_file,
        answer_api_url=args.answer_api_url,
        judge_api_url=args.judge_api_url,
        answer_model=args.answer_model,
        judge_model=args.judge_model,
        answer_api_key_env=args.answer_api_key_env,
        judge_api_key_env=args.judge_api_key_env,
        max_retries=args.max_retries,
        max_workers=args.max_workers,
    )
    model.compile()
    model.forward()
    # Find the latest reject_sampling cache step file.
    # NOTE(review): selection considers every matching file in ./cot_cache.
    # If higher-numbered step files from an earlier, longer run are still
    # present, they would be picked instead of this run's final step. Clear
    # the cache between runs with different settings.
    max_step_file = find_latest_step_file("./cot_cache", "reject_sampling", "jsonl")
    # Copy output alongside input_file so relative image paths remain valid
    output_dir = os.path.dirname(args.input_file)
    output_file = os.path.join(output_dir, "curated_vqa_with_cot.jsonl")
    shutil.copy(max_step_file, output_file)
    print(f"Curated data with cot saved to: {output_file}")