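# tools/utils/parquet/MathVerse_ptj.py
# Batch-convert MathVerse Parquet files into JSON annotation files.
# Uploaded by Adinosaur via huggingface_hub ("Upload folder using huggingface_hub", commit 1c980b1, verified).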
import argparse
import json
import os
import logging
import sys
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor
from ast import literal_eval
import time
from typing import Tuple
import pandas as pd
# Configure the root logger at import time: INFO and above, as
# "time - LEVEL - message" lines written to stdout.
logging.basicConfig(
    stream=sys.stdout,
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
def _to_str(value) -> str:
    """Return ``str(value).strip()``, or ``""`` when *value* is missing.

    ``None`` and NaN/NA (what pandas yields for empty Parquet cells) map to
    ``""`` instead of the literal strings ``"None"`` / ``"nan"``.
    """
    # is_scalar guard: pd.isna on an array-like returns an array, not a bool.
    if pd.api.types.is_scalar(value) and pd.isna(value):
        return ""
    return str(value).strip()


def _clean_text(value) -> str:
    """Return *value* as one line with every whitespace run collapsed to a space."""
    return " ".join(_to_str(value).split())


def _parse_choices(choices_part: str) -> list:
    """Parse a "Choices:" section into ``[{"id": ..., "text": ...}, ...]``.

    Each non-empty line is expected to look like ``A: text`` or ``A. text``.
    The line is split at whichever of ``:`` / ``.`` occurs first, so option
    text such as ``A. ratio 1:2`` keeps its inner colon. Lines containing
    neither separator are skipped.
    """
    options = []
    for line in choices_part.splitlines():
        line = line.strip()
        if not line:
            continue
        cut_points = [pos for pos in (line.find(":"), line.find(".")) if pos != -1]
        if not cut_points:
            continue  # skip lines that cannot be parsed
        cut = min(cut_points)
        options.append({
            "id": line[:cut].strip().upper(),
            "text": _clean_text(line[cut + 1:]),
        })
    return options


def process_row(args):
    """Convert one DataFrame row into a MathVerse VQA record (thread-safe).

    Args:
        args: ``(index, row, file_stem)`` tuple, packed so the function can be
            passed straight to an executor. ``row`` is read with ``.get``
            (e.g. a pandas Series); ``file_stem`` names the image sub-folder.

    Returns:
        The record dict, or ``None`` if converting the row raised (the error
        is logged with its traceback).
    """
    index, row, file_stem = args
    try:
        media_path = "./" + (Path("data") / file_stem / f"{index}.jpg").as_posix()
        question_type = _to_str(row.get("question_type")).lower()
        if question_type not in ("multi-choice", "free-form"):
            # Unknown types are passed through with empty question/options/answer.
            logging.warning(f"未知的题目类型: {question_type}")
        description = _to_str(row.get("problem_version"))

        options = []
        answer = []
        question_text = ""
        if question_type == "multi-choice":
            # question_for_eval holds the question, then "Choices:" and one option per line.
            question_for_eval = _to_str(row.get("question_for_eval"))
            question_part, _, choices_part = question_for_eval.partition("Choices:")
            question_text = _clean_text(question_part)
            options = _parse_choices(choices_part)
            # The answer is the option letter, used as-is (upper-cased); missing -> [].
            raw_answer = _clean_text(row.get("answer")).upper()
            if raw_answer:
                answer = [raw_answer]
        elif question_type == "free-form":
            question_text = _clean_text(row.get("query"))
            # Free-form answers are always a one-element string list ("" if missing).
            answer = [_clean_text(row.get("answer"))]

        return {
            "index": index,
            "media_type": "image",
            "media_paths": media_path,
            "description": description,
            "task_type": "Vision-Question-Answer",
            "question": [question_text],
            "question_type": question_type,
            "options": options,
            "annotations": [],
            "answer": answer,
            "source": "MathVerse",
            "domain": "Math",
        }
    except Exception as e:
        # Per-row best effort: the caller counts a None result as an error.
        logging.exception(f"处理行 {index} 时出错: {str(e)}")
        return None
def process_single_parquet(parquet_path: Path, output_root: Path) -> Tuple[int, int]:
    """Convert one Parquet file into ``<output_root>/<file stem>.json``.

    Rows are converted by ``process_row`` on a thread pool; rows for which it
    returns ``None`` are counted as errors and left out of the JSON.

    Args:
        parquet_path: Parquet file to read.
        output_root: Directory for the JSON output; created if missing.

    Returns:
        ``(success_count, error_count)``. If the file as a whole fails (read
        or write error), returns ``(0, total_rows)``, where ``total_rows`` is 0
        when the file could not be read at all.
    """
    start_time = time.time()
    parquet_path = Path(parquet_path)
    file_stem = parquet_path.stem
    output_dir = Path(output_root)
    output_json = output_dir / f"{file_stem}.json"
    # Bound before the try so the failure path can always report it.
    total_rows = 0
    try:
        df = pd.read_parquet(parquet_path)
        total_rows = len(df)
        logging.info(f"\n{'='*40}\nProcessing: {parquet_path.name}")
        logging.info(f"Output directory: {output_dir}")

        task_args = [(idx, row, file_stem) for idx, row in df.iterrows()]
        # os.cpu_count() may return None; fall back to a single CPU.
        max_workers = (os.cpu_count() or 1) * 2
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # map() yields results in input order, so the JSON keeps row order.
            converted = list(executor.map(process_row, task_args))
        results = [record for record in converted if record is not None]
        success_count = len(results)
        error_count = len(converted) - success_count

        output_dir.mkdir(parents=True, exist_ok=True)
        with open(output_json, 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=2)

        process_time = time.time() - start_time
        logging.info(
            f"Processed: {success_count}/{total_rows} | "
            f"Errors: {error_count} | "
            f"Time: {process_time:.2f}s"
        )
        return success_count, error_count
    except Exception as e:
        logging.exception(f"处理文件失败: {str(e)}")
        return 0, total_rows
def batch_process_parquets(input_dir: Path, output_root: Path):
    """Convert every ``*.parquet`` file directly under *input_dir* to JSON.

    Args:
        input_dir: Directory to scan (non-recursive); a ``str`` is accepted too.
        output_root: Directory receiving one JSON file per Parquet file;
            a ``str`` is accepted too.

    Raises:
        FileNotFoundError: If *input_dir* does not exist.
    """
    input_path = Path(input_dir)
    output_root = Path(output_root)
    if not input_path.exists():
        raise FileNotFoundError(f"输入目录不存在: {input_path}")
    # glob() order is filesystem-dependent; sort for a reproducible run order.
    parquet_files = sorted(input_path.glob("*.parquet"))
    if not parquet_files:
        logging.warning("未找到Parquet文件")
        return
    total_success = 0
    total_errors = 0
    for parquet_file in parquet_files:
        success, errors = process_single_parquet(parquet_file, output_root)
        total_success += success
        total_errors += errors
    logging.info(f"\n{'='*40}\n批量处理完成")
    logging.info(f"处理文件总数: {len(parquet_files)}")
    logging.info(f"总成功条目: {total_success}")
    logging.info(f"总失败条目: {total_errors}")
if __name__ == "__main__":
    # CLI entry point: convert every Parquet file in --input into JSON under --output.
    parser = argparse.ArgumentParser(description='批量处理Parquet文件转JSON')
    parser.add_argument('-i', '--input', required=True, help='输入目录路径')
    parser.add_argument('-o', '--output', required=True, help='输出根目录路径')
    args = parser.parse_args()
    try:
        start_time = time.time()
        batch_process_parquets(
            input_dir=args.input,
            output_root=args.output,
        )
        logging.info(f"\n总耗时: {time.time()-start_time:.2f}s")
    except Exception as e:
        # Top-level boundary: log with traceback, then exit with a non-zero status.
        logging.exception(f"程序异常终止: {str(e)}")
        sys.exit(1)