"""Reject-sampling pipeline that generates chain-of-thought (CoT) answers for VQA data.

Each round runs three steps:

1. Ask the answer model to respond to every question.
2. Strip the ``<think>...</think>`` reasoning span from the response.
3. Have a judge model compare the remaining short answer with the ground-truth
   ``answer`` column.

From the second round on, the generator receives ``answer_match_result`` as its
skip key so that items already answered correctly are not re-generated.

After the last round, the most recent cache step file is copied next to the
input file as ``curated_vqa_with_cot.jsonl``, so relative image paths remain
valid.
"""
import argparse
import os
import re
import shutil
import sys
from typing import Iterable

import pandas as pd

# Add the parent of this script's directory to sys.path so the project-local
# ``operators`` package imported below can be found. Must precede those imports.
sys.path.append(os.path.dirname(os.path.dirname(__file__)))

from dataflow.operators.core_text import GeneralFilter, PandasOperator
from operators.bench_evaluate import BenchDatasetEvaluatorQuestion
from operators.vqa_answer_generator import VQAReasoningAnswerGenerator
from dataflow.serving import APILLMServing_request, APIVLMServing_openai, LocalVLMServing_vllm
from dataflow.utils.storage import FileStorage
from dataflow.operators.reasoning import (
    ReasoningAnswerGenerator,
    ReasoningAnswerGroundTruthFilter,
)
from dataflow.prompts.reasoning.math import MathAnswerGeneratorPrompt
from dataflow import get_logger
from dataflow.pipeline import PipelineABC

# Directory where FileStorage writes its per-step caches and the judge its eval results.
CACHE_DIR = "./cot_cache"
# File-name prefix FileStorage uses for the step files in CACHE_DIR.
CACHE_FILE_PREFIX = "reject_sampling"
# Name of the final output file, written next to the input file.
OUTPUT_FILE_NAME = "curated_vqa_with_cot.jsonl"

_THINK_OPEN = "<think>"
_THINK_CLOSE = "</think>"


def make_remove_think_fn(input_key, output_key):
    """Build a DataFrame transform that strips ``<think>...</think>`` reasoning.

    The returned function copies the frame. When ``input_key`` is a column, it
    writes a cleaned version of that column to ``output_key``:

    * Missing values (per ``pd.isna``) pass through unchanged.
    * Other values are converted to ``str``.
    * Text without a closing ``</think>`` tag (any case) is only
      whitespace-stripped.
    * Otherwise ``<think>`` is prepended and every non-overlapping
      ``<think>...</think>`` span is removed. The match is case-insensitive and
      spans newlines. Prepending means a response that contains only the
      closing tag still has its leading reasoning removed.

    Args:
        input_key: Column holding the raw model response.
        output_key: Column that receives the cleaned answer.

    Returns:
        A callable ``fn(df) -> DataFrame`` suitable for ``PandasOperator``.
    """
    pattern = re.compile(
        re.escape(_THINK_OPEN) + r".*?" + re.escape(_THINK_CLOSE),
        flags=re.DOTALL | re.IGNORECASE,
    )

    def clean_text(value):
        if pd.isna(value):
            return value
        text = str(value)
        # Lower-case check so it agrees with the IGNORECASE pattern.
        if _THINK_CLOSE not in text.lower():
            return text.strip()
        return pattern.sub("", _THINK_OPEN + text).strip()

    def fn(df):
        df = df.copy()
        if input_key in df.columns:
            df[output_key] = df[input_key].apply(clean_text)
        return df

    return fn


class RejectSamplingPipeline(PipelineABC):
    """Generate VQA answers and re-sample until a judge LLM accepts them.

    Args:
        first_entry_file_name: Input JSONL read by the first storage step.
        answer_api_url: Base URL passed to the answer-model serving client.
        judge_api_url: Base URL of the OpenAI-compatible judge API;
            ``/chat/completions`` is appended to it.
        answer_model: Model name used for answer generation.
        judge_model: Model name used for answer verification.
        answer_api_key_env: Environment variable holding the answer API key.
        judge_api_key_env: Environment variable holding the judge API key.
        max_retries: Number of generate/clean/judge rounds run by ``forward``.
        max_workers: Parallel request workers for each serving client.
    """

    def __init__(self, first_entry_file_name, answer_api_url, judge_api_url,
                 answer_model, judge_model,
                 answer_api_key_env="DF_API_KEY", judge_api_key_env="DF_API_KEY",
                 max_retries=5, max_workers=100):
        super().__init__()
        self.storage = FileStorage(
            first_entry_file_name=first_entry_file_name,
            cache_path=CACHE_DIR,
            file_name_prefix=CACHE_FILE_PREFIX,
            cache_type="jsonl",
        )
        self.max_retries = max_retries
        self.logger = get_logger()

        # Answer model client (vision-language, sampled with temperature 0.7).
        self.llm_answer_serving = APIVLMServing_openai(
            api_url=answer_api_url,
            model_name=answer_model,
            key_name_of_api_key=answer_api_key_env,
            max_workers=max_workers,
            timeout=600.0,
            max_tokens=8192,
            temperature=0.7,
        )
        # Judge model client.
        self.llm_serving = APILLMServing_request(
            # Drop a trailing slash so the joined endpoint never contains "//".
            api_url=f"{judge_api_url.rstrip('/')}/chat/completions",
            model_name=judge_model,
            key_name_of_api_key=judge_api_key_env,
            max_workers=max_workers,
            read_timeout=300.0,
        )

        # Difficulty filter: keeps rows with accuracy <= 1.0.
        # NOTE(review): not invoked by forward(); confirm whether it is still needed.
        self.difficulty_filter = GeneralFilter(
            filter_rules=[lambda df: df['accuracy'] <= 1.0]
        )

        # LLM answer generation.
        self.answer_generator = VQAReasoningAnswerGenerator(
            llm_serving=self.llm_answer_serving,
            prompt_template=MathAnswerGeneratorPrompt(),
            skip_text_only=False,
        )
        # Strips <think>...</think> from generated_cot into llm_short_answer.
        self.think_cleaner = PandasOperator(process_fn=[
            make_remove_think_fn(input_key="generated_cot", output_key="llm_short_answer")
        ])
        # Identity step; used in forward() only to declare a key for compilation.
        self.noop = PandasOperator(process_fn=[
            lambda df: df
        ])

        # LLM verification against the ground-truth answer.
        self.answer_groundtruth_filter = BenchDatasetEvaluatorQuestion(
            compare_method="semantic",
            llm_serving=self.llm_serving,
            prompt_template=None,  # use the evaluator's default prompt
            eval_result_path=f"{CACHE_DIR}/eval_results.jsonl",
            support_subquestions=True,
            skip_true=True,
        )

    def forward(self):
        """Run ``max_retries`` rounds of generation, think-stripping and judging.

        A no-op step first declares ``answer_match_result``. It does nothing at
        runtime and exists for pipeline compilation, presumably so that later
        rounds can use the key as ``input_skip_key``.
        """
        self.noop.run(storage=self.storage.step(), output_key="answer_match_result")

        for round_idx in range(self.max_retries):
            # Round 0 answers everything; later rounds skip items already answered correctly.
            input_skip_key = "answer_match_result" if round_idx > 0 else None
            self.answer_generator.run(
                storage=self.storage.step(),
                input_key="question",
                output_key="generated_cot",
                input_skip_key=input_skip_key,
                input_image_basedir_key="image_basedir",
            )
            self.think_cleaner.run(storage=self.storage.step(), output_key="llm_short_answer")
            self.answer_groundtruth_filter.run(
                storage=self.storage.step(),
                input_test_answer_key="llm_short_answer",
                input_gt_answer_key="answer",
                input_question_key="question",
            )


def _find_latest_step_file(cache_dir=CACHE_DIR, prefix=CACHE_FILE_PREFIX):
    """Return the path of the highest-numbered ``<prefix>_step<N>.jsonl`` in ``cache_dir``.

    Only exact file names are considered, so stray files such as
    ``<prefix>_step3.jsonl.tmp`` are ignored.

    Raises:
        FileNotFoundError: If ``cache_dir`` contains no matching step file.
    """
    step_re = re.compile(rf"{re.escape(prefix)}_step(\d+)\.jsonl")
    candidates = []
    for name in os.listdir(cache_dir):
        match = step_re.fullmatch(name)
        if match:
            candidates.append((int(match.group(1)), name))
    if not candidates:
        raise FileNotFoundError(
            f"No '{prefix}_step<N>.jsonl' files found in {cache_dir!r}; "
            "did the pipeline run?"
        )
    _, latest_name = max(candidates)
    return os.path.join(cache_dir, latest_name)


def _parse_args():
    """Parse the command-line options for the reject-sampling pipeline."""
    parser = argparse.ArgumentParser(description="CoT Generation Pipeline with Reject Sampling")
    parser.add_argument("--input_file", type=str, required=True,
                        help="Path to the input JSONL file (curated_vqa.jsonl)")
    parser.add_argument("--max_retries", type=int, default=5,
                        help="Maximum number of reject sampling rounds")
    parser.add_argument("--answer_api_url", type=str, default="https://api.xxx.com/v1",
                        help="Url where you serve your qwen model (e.g. via vllm)")
    parser.add_argument("--judge_api_url", type=str, default="https://api.openai.com/v1",
                        help="Base URL of the OpenAI-compatible API for answer verification (e.g. https://api.openai.com/v1)")
    parser.add_argument("--answer_model", type=str, default="qwen3-vl-235b-thinking",
                        help="Model to use for answer generation")
    parser.add_argument("--judge_model", type=str, default="gpt-5-mini",
                        help="Model to use for answer verification")
    parser.add_argument("--answer_api_key_env", type=str, default="DF_API_KEY",
                        help="Environment variable name holding the API key for the answer model")
    parser.add_argument("--judge_api_key_env", type=str, default="DF_API_KEY",
                        help="Environment variable name holding the API key for the judge model")
    parser.add_argument("--max_workers", type=int, default=100,
                        help="Number of parallel API workers")
    return parser.parse_args()


def main():
    """Run the pipeline, then copy the final cache step next to the input file."""
    args = _parse_args()
    model = RejectSamplingPipeline(
        args.input_file,
        answer_api_url=args.answer_api_url,
        judge_api_url=args.judge_api_url,
        answer_model=args.answer_model,
        judge_model=args.judge_model,
        answer_api_key_env=args.answer_api_key_env,
        judge_api_key_env=args.judge_api_key_env,
        max_retries=args.max_retries,
        max_workers=args.max_workers,
    )
    model.compile()
    model.forward()

    latest_step_file = _find_latest_step_file()

    # Copy output alongside input_file so relative image paths remain valid.
    output_file = os.path.join(os.path.dirname(args.input_file), OUTPUT_FILE_NAME)
    shutil.copy(latest_step_file, output_file)
    print(f"Curated data with cot saved to: {output_file}")


if __name__ == "__main__":
    main()