import argparse
import json
import logging
import os
import sys
import time
from ast import literal_eval
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Tuple

import pandas as pd

# Log to stdout with timestamps so batch runs can be captured/redirected easily.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    stream=sys.stdout
)


def _clean_text(value) -> str:
    """Normalize a cell value to a single-line string.

    NaN/None become "" (pandas represents missing string cells as float NaN,
    which would otherwise stringify to "nan"); all whitespace runs, including
    newlines, are collapsed to single spaces.
    """
    if value is None or (isinstance(value, float) and pd.isna(value)):
        return ""
    return " ".join(str(value).split())


def _parse_options(choices_part: str) -> list:
    """Parse a "Choices:" section into [{"id": ..., "text": ...}, ...].

    One option per line; supports both "A: text" and "A. text" separators.
    Lines with neither separator are skipped.
    """
    options = []
    for line in choices_part.split("\n"):
        line = line.strip()
        if not line:
            continue
        # Prefer ":" so option text containing "." (e.g. decimals) still parses.
        if ":" in line:
            option_id, _, option_text = line.partition(":")
        elif "." in line:
            option_id, _, option_text = line.partition(".")
        else:
            continue  # skip unparseable lines
        options.append({
            "id": option_id.strip().upper(),
            "text": " ".join(option_text.split()),
        })
    return options


def process_row(args):
    """Convert one dataset row into a JSON-ready record (thread-safe).

    Args:
        args: tuple of (index, row, file_stem) where `row` is a pandas Series
              (or mapping) with the MathVerse columns and `file_stem` names the
              per-split image directory under ./data/.

    Returns:
        The record dict, or None if the row could not be processed.
    """
    index, row, file_stem = args
    try:
        # ====== basic fields ======
        media_path = "./" + (Path("data") / file_stem / f"{index}.jpg").as_posix()
        raw_question_type = _clean_text(row.get("question_type", "")).lower()

        # ====== question-type normalization ======
        # Only two types are defined by the dataset spec; anything else is
        # passed through unchanged but logged for inspection.
        if raw_question_type in ("multi-choice", "free-form"):
            formatted_question_type = raw_question_type
        else:
            logging.warning(f"未知的题目类型: {raw_question_type}")
            formatted_question_type = raw_question_type

        # ====== shared fields ======
        description = _clean_text(row.get("problem_version", ""))

        # ====== per-type fields ======
        options = []
        answer = []
        question_text = ""

        if formatted_question_type == "multi-choice":
            question_for_eval = str(row.get("question_for_eval", ""))
            # Split the prompt from its "Choices:" section (if present).
            if "Choices:" in question_for_eval:
                question_part, _, choices_part = question_for_eval.partition("Choices:")
            else:
                question_part, choices_part = question_for_eval, ""
            question_text = " ".join(question_part.split())
            options = _parse_options(choices_part)
            # Answer is the option letter; NaN/empty leaves `answer` as [].
            raw_answer = _clean_text(row.get("answer", "")).upper()
            if raw_answer:
                answer = [raw_answer]

        elif formatted_question_type == "free-form":
            question_text = _clean_text(row.get("query", ""))
            # Free-form answers are always emitted as a single-element list,
            # with missing values mapped to [""].
            answer = [_clean_text(row.get("answer", ""))]

        # ====== result record ======
        return {
            "index": index,
            "media_type": "image",
            "media_paths": media_path,
            "description": description,
            "task_type": "Vision-Question-Answer",
            "question": [question_text],
            "question_type": formatted_question_type,
            "options": options,
            "annotations": [],
            "answer": answer,
            "source": "MathVerse",
            "domain": "Math"
        }

    except Exception as e:
        logging.error(f"处理行 {index} 时出错: {str(e)}")
        return None


def process_single_parquet(parquet_path: Path, output_root: Path) -> Tuple[int, int]:
    """Convert one Parquet file to <output_root>/<stem>.json.

    Returns:
        (success_count, error_count). On a file-level failure returns
        (0, total_rows) where total_rows is 0 if the file never loaded.
    """
    start_time = time.time()
    file_stem = parquet_path.stem
    output_dir = output_root
    output_json = output_dir / f"{file_stem}.json"

    success_count = 0
    error_count = 0
    # Defined before the try so the except handler can reference it even when
    # pd.read_parquet itself fails (previously a NameError in that path).
    total_rows = 0
    results = []

    try:
        df = pd.read_parquet(parquet_path)
        total_rows = len(df)
        logging.info(f"\n{'='*40}\nProcessing: {parquet_path.name}")
        logging.info(f"Output directory: {output_dir}")

        # Thread pool: row processing is cheap; 2x CPUs is a reasonable cap.
        # os.cpu_count() may return None, hence the fallback.
        max_workers = (os.cpu_count() or 1) * 2
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [
                executor.submit(process_row, (idx, row, file_stem))
                for idx, row in df.iterrows()
            ]
            # Consume in submission order so output ordering matches the input.
            for future in futures:
                result = future.result()
                if result:
                    results.append(result)
                    success_count += 1
                else:
                    error_count += 1

        # Write the JSON output (non-ASCII preserved, pretty-printed).
        with open(output_json, "w", encoding="utf-8") as f:
            json.dump(results, f, ensure_ascii=False, indent=2)

        # Per-file report.
        process_time = time.time() - start_time
        logging.info(
            f"Processed: {success_count}/{total_rows} | "
            f"Errors: {error_count} | "
            f"Time: {process_time:.2f}s"
        )
        return success_count, error_count

    except Exception as e:
        logging.error(f"处理文件失败: {str(e)}")
        return 0, total_rows


def batch_process_parquets(input_dir: Path, output_root: Path):
    """Process every *.parquet file directly under input_dir.

    Raises:
        FileNotFoundError: if input_dir does not exist.
    """
    input_path = Path(input_dir)
    output_root = Path(output_root)

    if not input_path.exists():
        raise FileNotFoundError(f"输入目录不存在: {input_path}")

    parquet_files = list(input_path.glob("*.parquet"))
    if not parquet_files:
        logging.warning("未找到Parquet文件")
        return

    total_stats = {'success': 0, 'errors': 0}
    for parquet_file in parquet_files:
        success, errors = process_single_parquet(parquet_file, output_root)
        total_stats['success'] += success
        total_stats['errors'] += errors

    logging.info(f"\n{'='*40}\n批量处理完成")
    logging.info(f"处理文件总数: {len(parquet_files)}")
    logging.info(f"总成功条目: {total_stats['success']}")
    logging.info(f"总失败条目: {total_stats['errors']}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='批量处理Parquet文件转JSON')
    parser.add_argument('-i', '--input', required=True, help='输入目录路径')
    parser.add_argument('-o', '--output', required=True, help='输出根目录路径')
    args = parser.parse_args()

    try:
        start_time = time.time()
        batch_process_parquets(
            input_dir=args.input,
            output_root=args.output
        )
        logging.info(f"\n总耗时: {time.time()-start_time:.2f}s")
    except Exception as e:
        logging.error(f"程序异常终止: {str(e)}")
        sys.exit(1)