"""Prepare a filtered, balanced VQA v2 subset for generation-style training.

Pipeline: load questions + annotations, pick a deterministic majority answer
per question, drop low-information / ambiguous pairs, keep only frequent
answers, cap samples per answer to balance the label distribution, copy the
referenced COCO images, and write qa_pairs.json + metadata.csv.
"""

import json
import os
import random
import shutil
from collections import Counter, defaultdict

import pandas as pd

try:
    from tqdm import tqdm
except ImportError:  # progress bars are cosmetic — degrade gracefully
    def tqdm(iterable, **kwargs):
        return iterable

# --- Paths (relative to this script's working directory) ------------------
IMAGES_DIR = r"../train2014"
QUESTIONS_PATH = r"../v2_OpenEnded_mscoco_train2014_questions.json"
ANNOTATIONS_PATH = r"../v2_mscoco_train2014_annotations.json"
OUTPUT_DIR = "./gen_vqa_v2"

# --- Tunables --------------------------------------------------------------
RANDOM_SEED = 42              # fixed seed so the sampled subset is reproducible
MIN_ANSWER_FREQ = 20          # drop answers rarer than this across the corpus
MAX_SAMPLES_PER_ANSWER = 600  # cap per answer to balance the label distribution
MAX_ANSWER_WORDS = 5          # reject long free-form answers
MAX_ANSWER_CHARS = 30

# Answers that carry little information for a generation task.
EXCLUDED_ANSWERS = ['yes', 'no', 'unknown', 'none', 'n/a', 'cant tell', 'not sure']
# Question phrasings too vague to ground a specific answer.
AMBIGUOUS_QUESTIONS = ['what is in the image', 'what is this', 'what is that', 'what do you see']


def select_main_answer(answers):
    """Return the most frequent raw answer, lowercased and stripped.

    Ties break deterministically by first occurrence (``Counter.most_common``
    preserves insertion order), unlike ``max(set(...), key=...)`` whose winner
    depends on set iteration order and thus on hash randomization.
    """
    (answer, _count), = Counter(answers).most_common(1)
    return answer.lower().strip()


def build_qa_pairs(questions, qid_to_ann):
    """Merge questions with their annotations and apply quality filters.

    Returns ``(merged_data, answer_counter)`` where ``merged_data`` is a list
    of dicts with image_id / question_id / question / answer, and
    ``answer_counter`` counts occurrences of each kept answer.
    """
    merged_data = []
    answer_counter = Counter()
    for q in tqdm(questions):
        ann = qid_to_ann.get(q["question_id"])
        if ann is None:
            continue
        # Keep only non-blank annotator answers.
        answers = [a["answer"] for a in ann["answers"] if a["answer"].strip()]
        if not answers:
            continue
        main_answer = select_main_answer(answers)
        question_text = q["question"].lower().strip()
        if main_answer in EXCLUDED_ANSWERS:
            continue
        if any(ambig in question_text for ambig in AMBIGUOUS_QUESTIONS):
            continue
        # Reject long/free-form answers; keep short, phrase-like ones.
        if len(main_answer.split()) > MAX_ANSWER_WORDS or len(main_answer) > MAX_ANSWER_CHARS:
            continue
        merged_data.append({
            "image_id": q["image_id"],
            "question_id": q["question_id"],
            "question": q["question"],
            "answer": main_answer
        })
        answer_counter[main_answer] += 1
    return merged_data, answer_counter


def filter_by_frequency(pairs, answer_counter, min_freq=MIN_ANSWER_FREQ):
    """Keep only pairs whose answer occurs at least ``min_freq`` times."""
    frequent_answers = {ans for ans, count in answer_counter.items() if count >= min_freq}
    return [item for item in pairs if item["answer"] in frequent_answers]


def balance_by_answer(pairs, max_per_answer=MAX_SAMPLES_PER_ANSWER):
    """Cap each answer at ``max_per_answer`` samples and shuffle the result.

    Shuffling uses the module-level ``random`` state; seed it (see ``main``)
    for reproducible output.
    """
    answer_samples = defaultdict(list)
    for item in pairs:
        bucket = answer_samples[item["answer"]]
        if len(bucket) < max_per_answer:
            bucket.append(item)
    balanced = [item for bucket in answer_samples.values() for item in bucket]
    random.shuffle(balanced)
    return balanced, len(answer_samples)


def copy_images(pairs):
    """Copy each pair's COCO image into OUTPUT_DIR/images.

    Pairs whose source image is missing on disk are silently dropped;
    surviving items gain a relative ``image_path`` field.
    """
    final_data = []
    for item in tqdm(pairs):
        img_name = f"COCO_train2014_{item['image_id']:012d}.jpg"
        src_path = os.path.join(IMAGES_DIR, img_name)
        dst_path = os.path.join(OUTPUT_DIR, "images", img_name)
        if os.path.exists(src_path):
            shutil.copy(src_path, dst_path)
            item["image_path"] = f"images/{img_name}"
            final_data.append(item)
    return final_data


def main():
    """Run the full preparation pipeline and write qa_pairs.json / metadata.csv."""
    random.seed(RANDOM_SEED)  # make the balanced shuffle reproducible
    os.makedirs(os.path.join(OUTPUT_DIR, "images"), exist_ok=True)

    print("Loading VQA v2 data...")
    with open(QUESTIONS_PATH, "r") as f:
        questions = json.load(f)["questions"]
    with open(ANNOTATIONS_PATH, "r") as f:
        annotations = json.load(f)["annotations"]
    qid_to_ann = {ann["question_id"]: ann for ann in annotations}

    print("Merging questions and answers...")
    merged_data, answer_counter = build_qa_pairs(questions, qid_to_ann)
    print(f"Total valid Q-A pairs (after filtering): {len(merged_data)}")

    filtered_data = filter_by_frequency(merged_data, answer_counter)
    print(f"After frequency filtering (min_freq={MIN_ANSWER_FREQ}): {len(filtered_data)} pairs")

    balanced_data, n_unique = balance_by_answer(filtered_data)
    print(f"After balancing: {len(balanced_data)} pairs with {n_unique} unique answers")

    print("Copying selected images and saving data...")
    final_data = copy_images(balanced_data)
    print(f"Final dataset: {len(final_data)} pairs")

    with open(os.path.join(OUTPUT_DIR, "qa_pairs.json"), "w") as f:
        json.dump(final_data, f, indent=2, ensure_ascii=False)
    pd.DataFrame(final_data).to_csv(os.path.join(OUTPUT_DIR, "metadata.csv"), index=False)
    print("Data preparation complete.")


if __name__ == "__main__":
    main()