# Uploaded by DingZhenDojoCat via the upload-large-folder tool (commit 7ed0fb5, verified).
import json
import random
import os
from openai import OpenAI
from tqdm import tqdm
import concurrent.futures
from typing import List, Dict, Optional
from datetime import datetime
from threading import Lock
import time
from prompt import R1_SYS_PROMPT
# Initialize the client
# SiliconFlow-hosted OpenAI-compatible endpoint. The API key is read from the
# SL_KEY environment variable; the fallback string is a non-functional placeholder.
client = OpenAI(
api_key=os.environ.get("SL_KEY", "YOUR_SILCONFLOW_KEY"),
base_url="https://api.siliconflow.cn/v1",
)
# Create a lock for thread-safe file writing
# Shared by all worker threads in write_to_jsonl so JSONL lines never interleave.
file_lock = Lock()
def format_query(qa_dict: Dict, v2=False) -> str:
    """Build the prompt text for one QA record.

    Concatenates a fixed instruction header, the scene description, and the
    question. When *v2* is true, appends explicit answer-format instructions
    telling the model to emit '**The answer is: **' on its own line.
    """
    parts = [
        "Answer the question according to scene description.\n\n",
        qa_dict["description"],
        f"\nQuestion:\n{qa_dict['q']}",
    ]
    if v2:
        parts.extend([
            "\nInstructions:\n",
            "1. Carefully analyze the scene description\n",
            "2. Provide your reasoning if necessary\n",
            "3. For the final answer, start a new line with '**The answer is: **' followed by your answer\n",
        ])
    return "".join(parts)
def write_to_jsonl(result: Dict, filename: str):
    """Thread-safe append of *result* as one JSON line to *filename*.

    Serializes writers on the module-level ``file_lock`` so concurrent worker
    threads never interleave partial lines in the output file.
    """
    with file_lock:
        # Explicit utf-8: the platform default encoding is locale-dependent
        # and could otherwise fail or mis-encode non-ASCII content.
        with open(filename, 'a', encoding='utf-8') as f:
            f.write(json.dumps(result) + '\n')
def query_r1(qa_pair: Dict, output_file: str, model: str = "deepseek-ai/DeepSeek-R1", v2=False) -> Optional[Dict]:
    """Send one QA pair to the chat-completions API and persist the outcome.

    On success, the input record annotated with the model response and a
    timestamp is appended to *output_file* and returned. On any failure, an
    error record is appended to ``errors_<output_file>`` and ``None`` is
    returned. Sleeps after every call to pace requests against the API.
    """
    prompt_text = format_query(qa_pair, v2=v2)
    try:
        completion = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": R1_SYS_PROMPT},
                {"role": "user", "content": prompt_text},
            ],
            stream=False,
            max_tokens=4096,
        )
        record = dict(qa_pair)
        record["r1_response"] = completion.choices[0].message.content
        record["timestamp"] = datetime.now().isoformat()
        # Persist immediately so completed work survives interruption.
        write_to_jsonl(record, output_file)
        time.sleep(4)  # pacing between successful requests
        return record
    except Exception as e:
        # Broad catch is deliberate: any API/network error is logged and
        # the pair is skipped rather than killing the whole batch.
        print(f"Error processing query: {e}")
        failure = dict(qa_pair)
        failure["error"] = str(e)
        failure["timestamp"] = datetime.now().isoformat()
        write_to_jsonl(failure, f"errors_{output_file}")
        time.sleep(10)  # back off longer after a failure
        return None
def process_qa_pairs_parallel(qa_pairs: List[Dict], output_file: str, max_workers: int = 10) -> List[Dict]:
    """Fan query_r1 out over *qa_pairs* using a thread pool.

    The v2 prompt format is enabled when "v2" appears in *output_file*.
    Returns the list of successfully processed records; failures are
    already logged (by query_r1 or here) and simply skipped.
    """
    collected: List[Dict] = []
    ok_count = 0
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = [
            pool.submit(query_r1, pair, output_file, v2="v2" in output_file)
            for pair in qa_pairs
        ]
        # Consume futures as they finish, with a progress bar.
        for done in tqdm(concurrent.futures.as_completed(pending), total=len(pending)):
            try:
                outcome = done.result()
            except Exception as e:
                print(f"Failed to process query: {e}")
                continue
            if outcome is not None:
                collected.append(outcome)
                ok_count += 1
    return collected
if __name__ == "__main__":
    # Load QA pairs, shuffle deterministically, and cap the batch at 10k.
    random.seed(1234)
    with open("/home/lilei/Visual-R1/data/clever_counting_problems_clevr_cogent_v1.0_trainA.json") as f:
        qa_pairs = json.load(f)
    random.shuffle(qa_pairs)
    qa_pairs = qa_pairs[:10000]

    # Fixed filename (no timestamp) so re-runs can resume from previous output.
    output_file = "r1_results_clevr_cogent_v1.0_trainA_v2.jsonl"

    def _dedup_key(ins: Dict) -> str:
        # Identity of a QA pair for resume purposes: image + question + answer.
        return ins["img_filename"] + "-" + ins["q"] + "-" + str(ins["a"])

    # Resume support: collect keys already present in a previous run's output.
    # Guard against the first run, where no results file exists yet.
    finished = set()
    if os.path.exists(output_file):
        with open(output_file, 'r') as f:
            for line in f:
                finished.add(_dedup_key(json.loads(line)))
    qa_pairs = [ins for ins in qa_pairs if _dedup_key(ins) not in finished]
    print("Finished: ", len(finished))
    print("Remaining: ", len(qa_pairs))

    # Process QA pairs in parallel.
    r1_results = process_qa_pairs_parallel(qa_pairs, output_file)

    # Print final statistics.
    print(f"Successfully processed {len(r1_results)} out of {len(qa_pairs)} queries")
    print(f"Results saved to {output_file}")
    print(f"Any errors were saved to errors_{output_file}")