# tools/utils/parquet/MathVision_ptj.py
# Author: Adinosaur
# Upload folder using huggingface_hub
# Commit 1c980b1 (verified)
import argparse
import json
import os
import logging
import sys
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor
from ast import literal_eval
import time
from typing import Tuple
import pandas as pd
# Root logger setup: INFO and above, timestamped "time - LEVEL - message"
# lines, written to stdout.
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(stream=sys.stdout, format=_LOG_FORMAT, level=logging.INFO)
def _parse_options(options_data):
    """Normalize a raw ``options`` cell into a plain Python list.

    Accepts a list/tuple/array-like value, or a string holding a Python literal
    list such as ``"['1', '2']"``. Missing values (``None``/NaN), other scalars,
    unparsable strings and string literals that are not a list/tuple all yield
    ``[]``, so the row is treated as free-form.
    """
    if isinstance(options_data, str):
        try:
            parsed = literal_eval(options_data)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            return []
        # Only a literal list/tuple is an option set; a bare string or number
        # must not be split into per-character "options".
        return list(parsed) if isinstance(parsed, (list, tuple)) else []
    if options_data is None or pd.api.types.is_scalar(options_data):
        # None, NaN or any other scalar: there is no sequence of options.
        return []
    # Arrays, Series, tuples and similar iterables.
    return list(options_data)


def _match_answer(options, answer_text):
    """Return ``[id]`` of the option matching *answer_text*, or ``[]``.

    The option text is compared first (first match wins). If no text matches
    but *answer_text* equals one of the option letters, that letter is used.
    NOTE(review): this assumes some rows store the answer as a label such as
    "B" rather than the option text; confirm against the dataset.
    """
    for option in options:
        if option["text"] == answer_text:
            return [option["id"]]
    for option in options:
        if option["id"] == answer_text:
            return [option["id"]]
    return []


def _free_form_answer(raw_answer):
    """Return the free-form answer as a one-element list of cleaned text.

    Missing values (``None``/NaN, or the literal strings "nan"/"None") become
    ``[""]``; otherwise whitespace runs are collapsed to single spaces.
    """
    if isinstance(raw_answer, str):
        if raw_answer in ("nan", "None"):
            return [""]
    elif pd.api.types.is_scalar(raw_answer) and pd.isna(raw_answer):
        # Only scalars are NaN-checked: pd.isna on a list/array returns an
        # array whose truth value is ambiguous.
        return [""]
    return [" ".join(str(raw_answer).strip().split())]


def process_row(args):
    """Convert one DataFrame row into a MathVision JSON record.

    Intended for use from a thread pool. It only reads its arguments and builds
    a new dict; its only other side effect is logging.

    Args:
        args: ``(index, row, file_stem)``. ``index`` is the row label, ``row``
            is any mapping with ``.get`` (e.g. a ``pd.Series``), and
            ``file_stem`` is the source Parquet file's stem.

    Returns:
        The record dict, or ``None`` if the row could not be converted (the
        error is logged).
    """
    index, row, file_stem = args
    try:
        # Image path derived from the source file and row label:
        # ./data/<file_stem>/<index>.jpg
        media_path = "./" + (Path("data") / file_stem / f"{index}.jpg").as_posix()
        description = row.get("subject", "")  # the subject field is used as the description

        # Any options at all means multiple-choice; otherwise free-form.
        options_data = _parse_options(row.get("options", []))
        formatted_question_type = "multi-choice" if options_data else "free-form"

        options = []
        if formatted_question_type == "multi-choice":
            # Standard option structure with letter ids A, B, C, ...
            options = [
                {"id": chr(65 + i), "text": str(text).strip()}
                for i, text in enumerate(options_data)
            ]
            answer = _match_answer(options, str(row.get("answer", "")).strip())
        else:
            answer = _free_form_answer(row.get("answer", ""))

        return {
            "index": index,
            "media_type": "image",
            "media_paths": media_path,
            "description": description,
            "task_type": "Vision-Question-Answer",
            "question": [row.get('question', '')],
            "question_type": formatted_question_type,
            "options": options,
            "annotations": [],
            "answer": answer,
            "source": "MathVision",
            "domain": "Math"
        }
    except Exception as e:
        # Per-row boundary: one bad row must not abort the whole file.
        logging.error(f"处理行 {index} 时出错: {str(e)}")
        return None
def process_single_parquet(parquet_path: Path, output_root: Path) -> Tuple[int, int]:
    """Convert one Parquet file into a JSON array of records.

    Each row is converted by ``process_row`` on a thread pool. Successful
    records are written in row order to ``<output_root>/<parquet stem>.json``.

    Args:
        parquet_path: Parquet file to read (str or Path).
        output_root: Directory for the JSON output (str or Path); created if it
            does not exist.

    Returns:
        ``(success_count, error_count)``. If the file as a whole fails (read
        or write error), the error is logged and ``(0, total_rows)`` is
        returned; ``total_rows`` is 0 when the file could not be read at all.
    """
    start_time = time.time()
    parquet_path = Path(parquet_path)
    output_dir = Path(output_root)
    file_stem = parquet_path.stem
    output_json = output_dir / f"{file_stem}.json"
    # Bound before the try so the error path can report it even when
    # read_parquet itself fails (previously an UnboundLocalError).
    total_rows = 0
    try:
        df = pd.read_parquet(parquet_path)
        total_rows = len(df)
        logging.info(f"\n{'='*40}\nProcessing: {parquet_path.name}")
        logging.info(f"Output Directory: {output_dir.name}")

        # os.cpu_count() may return None when the count cannot be determined.
        # NOTE(review): process_row is pure-Python work, so threads largely
        # serialize under the GIL; a process pool may be faster.
        max_workers = min((os.cpu_count() or 1) * 2, 32)
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            task_args = [(idx, row, file_stem) for idx, row in df.iterrows()]
            # map() yields results in submission order, so the output keeps row order.
            row_results = list(executor.map(process_row, task_args))

        # process_row returns None for rows it could not convert.
        results = [record for record in row_results if record is not None]
        success_count = len(results)
        error_count = len(row_results) - success_count

        output_dir.mkdir(parents=True, exist_ok=True)
        with open(output_json, 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=2)

        process_time = time.time() - start_time
        logging.info(
            f"Processed: {success_count}/{total_rows} | "
            f"Errors: {error_count} | "
            f"Time: {process_time:.2f}s"
        )
        return success_count, error_count
    except Exception as e:
        # Per-file boundary: count every row of a failed file as an error.
        logging.error(f"处理文件失败: {str(e)}")
        return 0, total_rows
def batch_process_parquets(input_dir: Path, output_root: Path):
    """Convert every ``*.parquet`` file directly inside *input_dir* to JSON.

    Files are processed in sorted path order so that runs and logs are
    reproducible. Per-file success/error counts are summed and logged at the end.

    Args:
        input_dir: Directory to scan (str or Path); subdirectories are not searched.
        output_root: Directory that receives one ``<stem>.json`` per input
            file (str or Path); created if it does not exist.

    Raises:
        FileNotFoundError: if *input_dir* does not exist.
    """
    input_path = Path(input_dir)
    output_root = Path(output_root)
    if not input_path.exists():
        raise FileNotFoundError(f"输入目录不存在: {input_path}")
    # glob() order is filesystem-dependent; sort for deterministic processing.
    parquet_files = sorted(input_path.glob("*.parquet"))
    if not parquet_files:
        logging.warning("未找到Parquet文件")
        return
    output_root.mkdir(parents=True, exist_ok=True)
    total_stats = {'success': 0, 'errors': 0}
    for parquet_file in parquet_files:
        success, errors = process_single_parquet(parquet_file, output_root)
        total_stats['success'] += success
        total_stats['errors'] += errors
    logging.info(f"\n{'='*40}\n批量处理完成")
    logging.info(f"处理文件总数: {len(parquet_files)}")
    logging.info(f"总成功条目: {total_stats['success']}")
    logging.info(f"总失败条目: {total_stats['errors']}")
if __name__ == "__main__":
    # Command-line entry point: read the input/output directories and run the
    # batch conversion; any uncaught error is logged and exits with status 1.
    cli = argparse.ArgumentParser(description='批量处理Parquet文件转JSON')
    cli.add_argument('-i', '--input', required=True, help='输入目录路径')
    cli.add_argument('-o', '--output', required=True, help='输出根目录路径')
    cli_args = cli.parse_args()

    t0 = time.time()
    try:
        batch_process_parquets(input_dir=cli_args.input, output_root=cli_args.output)
        logging.info(f"\n总耗时: {time.time()-t0:.2f}s")
    except Exception as e:
        logging.error(f"程序异常终止: {str(e)}")
        sys.exit(1)