Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| import pandas as pd | |
| sys.path.append(os.path.dirname(os.path.dirname(__file__))) | |
| from dataflow.operators.core_text import PandasOperator | |
| from operators.bench_evaluate import BenchDatasetEvaluatorQuestion | |
| from operators.vqa_answer_generator import VQAReasoningAnswerGenerator | |
| from dataflow.serving import APILLMServing_request, APIVLMServing_openai, LocalVLMServing_vllm | |
| from dataflow.utils.storage import FileStorage | |
| from dataflow.operators.reasoning import ( | |
| ReasoningAnswerGenerator, | |
| ReasoningAnswerGroundTruthFilter | |
| ) | |
| from dataflow.prompts.reasoning.math import MathAnswerGeneratorPrompt | |
| from dataflow.operators.core_text import GeneralFilter | |
| from dataflow import get_logger | |
| from dataflow.pipeline import PipelineABC | |
| from typing import Iterable | |
| import re | |
| import argparse | |
| import shutil | |
def make_remove_think_fn(input_key, output_key):
    """Build a DataFrame transform that strips ``<think>...</think>`` blocks.

    Args:
        input_key: Name of the column holding the raw model output.
        output_key: Name of the column that receives the cleaned text.

    Returns:
        A function ``fn(df)`` that returns a copy of ``df``. When
        ``input_key`` is a column of ``df``, the copy gains ``output_key``
        with the cleaned text; otherwise the copy is returned unchanged.
    """
    pattern = re.compile(r'<think>.*?</think>', flags=re.DOTALL | re.IGNORECASE)

    def clean_text(value):
        """Return ``value`` without its reasoning block, whitespace-trimmed."""
        if not isinstance(value, str):
            # Missing values pass through untouched.
            if pd.isna(value):
                return value
            # Coerce before any substring test: ``"</think>" in 5`` would
            # raise TypeError.
            value = str(value)
        # The pattern is case-insensitive, so the guard must be as well;
        # otherwise an upper-case closing tag would never be stripped.
        if "</think>" not in value.lower():
            return value.strip()
        # An opening tag is prepended so that output containing only the
        # closing tag still matches the pattern (presumably some models emit
        # the reasoning without the opening tag -- confirm with the serving
        # setup). A duplicated opening tag is harmless: the non-greedy match
        # swallows it together with the block.
        return pattern.sub("", "<think>" + value).strip()

    def fn(df):
        df = df.copy()  # never mutate the caller's frame
        if input_key in df.columns:
            df[output_key] = df[input_key].apply(clean_text)
        return df

    return fn
class RejectSamplingPipeline(PipelineABC):
    """Generate chain-of-thought answers and verify them over several rounds.

    Each round runs three operators in order: the answer model writes a
    reasoning trace into ``generated_cot``, the ``<think>`` block is stripped
    into ``llm_short_answer``, and the judge model compares that short answer
    with the ``answer`` column.
    """

    def __init__(self, first_entry_file_name, answer_api_url, judge_api_url, answer_model, judge_model,
                 answer_api_key_env="DF_API_KEY", judge_api_key_env="DF_API_KEY",
                 max_retries=5, max_workers=100, cache_path="./cot_cache"):
        """Create the storage, the two model clients and the operators.

        Args:
            first_entry_file_name: Input file handed to ``FileStorage``.
            answer_api_url: URL passed as-is to the answer-model client.
            judge_api_url: Base URL of the judge API; ``/chat/completions``
                is appended to it.
            answer_model: Model name used for answer generation.
            judge_model: Model name used for answer verification.
            answer_api_key_env: Name of the environment variable holding the
                answer model's API key.
            judge_api_key_env: Name of the environment variable holding the
                judge model's API key.
            max_retries: Number of generate/clean/verify rounds run by
                ``forward``.
            max_workers: Worker count passed to both model clients.
            cache_path: Directory for the step cache files and
                ``eval_results.jsonl``.
        """
        super().__init__()
        self.storage = FileStorage(
            first_entry_file_name=first_entry_file_name,
            cache_path=cache_path,
            file_name_prefix="reject_sampling",
            cache_type="jsonl",
        )
        self.max_retries = max_retries
        self.logger = get_logger()
        self.llm_answer_serving = APIVLMServing_openai(
            api_url=answer_api_url,
            model_name=answer_model,
            key_name_of_api_key=answer_api_key_env,
            max_workers=max_workers,
            timeout=600.0,
            max_tokens=8192,
            temperature=0.7,
        )
        self.llm_serving = APILLMServing_request(
            # Strip trailing slashes so a base URL such as ".../v1/" does not
            # become ".../v1//chat/completions".
            api_url=f"{judge_api_url.rstrip('/')}/chat/completions",
            model_name=judge_model,
            key_name_of_api_key=judge_api_key_env,
            max_workers=max_workers,
            read_timeout=300.0
        )
        # Difficulty filter (keep items where accuracy <= 1.0).
        # NOTE: constructed here but not invoked by forward().
        self.difficulty_filter = GeneralFilter(
            filter_rules=[lambda df: df['accuracy'] <= 1.0]
        )
        # LLM answer generation
        self.answer_generator = VQAReasoningAnswerGenerator(
            llm_serving=self.llm_answer_serving,
            prompt_template=MathAnswerGeneratorPrompt(),
            skip_text_only=False,
        )
        # Derives "llm_short_answer" from "generated_cot" by dropping the
        # <think> block.
        self.think_cleaner = PandasOperator(
            process_fn=[make_remove_think_fn(input_key="generated_cot", output_key="llm_short_answer")]
        )
        # Identity transform; forward() runs it once before the rounds start.
        self.noop = PandasOperator(process_fn=[lambda df: df])
        # LLM verification
        self.answer_groundtruth_filter = BenchDatasetEvaluatorQuestion(
            compare_method="semantic",
            llm_serving=self.llm_serving,
            prompt_template=None,  # using default prompt
            eval_result_path=os.path.join(cache_path, "eval_results.jsonl"),
            support_subquestions=True,
            skip_true=True
        )

    def forward(self):
        """Run ``max_retries`` rounds of generate -> clean -> verify."""
        # For pipeline compilation only: the data is passed through unchanged.
        self.noop.run(storage=self.storage.step(), output_key="answer_match_result")
        for round_idx in range(self.max_retries):
            # The first round passes no skip key; later rounds pass the
            # column named "answer_match_result".
            input_skip_key = "answer_match_result" if round_idx > 0 else None
            # Generate answers (skip items already answered correctly)
            self.answer_generator.run(
                storage=self.storage.step(),
                input_key="question",
                output_key="generated_cot",
                input_skip_key=input_skip_key,
                input_image_basedir_key="image_basedir",
            )
            self.think_cleaner.run(storage=self.storage.step(), output_key="llm_short_answer")
            self.answer_groundtruth_filter.run(
                storage=self.storage.step(),
                input_test_answer_key="llm_short_answer",
                input_gt_answer_key="answer",
                input_question_key="question",
            )
def find_latest_step_file(cache_dir, prefix="reject_sampling"):
    """Return the path of the highest-numbered step file in ``cache_dir``.

    Only names of the exact form ``<prefix>_step<N>.jsonl`` are considered,
    so leftovers such as ``reject_sampling_step9.jsonl.bak`` are ignored.

    Args:
        cache_dir: Directory to scan.
        prefix: File name prefix that precedes ``_step<N>.jsonl``.

    Returns:
        ``cache_dir`` joined with the matching file name whose ``N`` is
        largest.

    Raises:
        FileNotFoundError: If ``cache_dir`` does not exist or holds no
            matching step file.
    """
    step_pattern = re.compile(re.escape(prefix) + r"_step(\d+)\.jsonl")
    latest = None  # (step number, file name) of the best match so far
    for name in os.listdir(cache_dir):
        matched = step_pattern.fullmatch(name)
        if matched is None:
            continue
        step = int(matched.group(1))
        if latest is None or step > latest[0]:
            # Keep the real file name instead of rebuilding it from the
            # integer, so a zero-padded step number still resolves.
            latest = (step, name)
    if latest is None:
        raise FileNotFoundError(
            f"No '{prefix}_step<N>.jsonl' file found in {cache_dir!r}"
        )
    return os.path.join(cache_dir, latest[1])


def main():
    """Parse CLI arguments, run the pipeline and copy out the final step file."""
    parser = argparse.ArgumentParser(description="CoT Generation Pipeline with Reject Sampling")
    parser.add_argument("--input_file", type=str, required=True, help="Path to the input JSONL file (curated_vqa.jsonl)")
    parser.add_argument("--max_retries", type=int, default=5, help="Maximum number of reject sampling rounds")
    parser.add_argument("--answer_api_url", type=str, default="https://api.xxx.com/v1", help="Url where you serve your qwen model (e.g. via vllm)")
    parser.add_argument("--judge_api_url", type=str, default="https://api.openai.com/v1", help="Base URL of the OpenAI-compatible API for answer verification (e.g. https://api.openai.com/v1)")
    parser.add_argument("--answer_model", type=str, default="qwen3-vl-235b-thinking", help="Model to use for answer generation")
    parser.add_argument("--judge_model", type=str, default="gpt-5-mini", help="Model to use for answer verification")
    parser.add_argument("--answer_api_key_env", type=str, default="DF_API_KEY", help="Environment variable name holding the API key for the answer model")
    parser.add_argument("--judge_api_key_env", type=str, default="DF_API_KEY", help="Environment variable name holding the API key for the judge model")
    parser.add_argument("--max_workers", type=int, default=100, help="Number of parallel API workers")
    args = parser.parse_args()

    model = RejectSamplingPipeline(
        args.input_file,
        answer_api_url=args.answer_api_url,
        judge_api_url=args.judge_api_url,
        answer_model=args.answer_model,
        judge_model=args.judge_model,
        answer_api_key_env=args.answer_api_key_env,
        judge_api_key_env=args.judge_api_key_env,
        max_retries=args.max_retries,
        max_workers=args.max_workers,
    )
    model.compile()
    model.forward()

    # Find the latest reject_sampling cache step file.
    # NOTE(review): step files left in ./cot_cache by an earlier, longer run
    # would also match and could be picked instead of this run's output --
    # confirm whether the cache is cleared between runs.
    max_step_file = find_latest_step_file("./cot_cache")

    # Copy output alongside input_file so relative image paths remain valid
    output_dir = os.path.dirname(args.input_file)
    output_file = os.path.join(output_dir, "curated_vqa_with_cot.jsonl")
    shutil.copy(max_step_file, output_file)
    print(f"Curated data with cot saved to: {output_file}")


if __name__ == "__main__":
    main()