tools / utils /parquet /merge_jp.py
Adinosaur's picture
Upload folder using huggingface_hub
1c980b1 verified
import argparse
import json
import shutil
from pathlib import Path
import pyarrow.parquet as pq
def merge_chemqa_dataset(dataset_dir: str, data_dir: str, output_name: str = "ChemQA") -> None:
"""
自动合并dataset_dir中所有Parquet文件对应的数据部分
参数:
dataset_dir : 存放.parquet文件的目录
data_dir : 存放图片文件夹的根目录(JSON文件位于其父目录)
output_name : 合并后的数据集名称
"""
# 转换为Path对象
dataset_path = Path(dataset_dir)
data_path = Path(data_dir)
# 获取所有需要合并的Parquet文件(按文件名排序)
parquet_files = sorted(dataset_path.glob("*.parquet"))
if not parquet_files:
raise FileNotFoundError("未找到任何Parquet文件")
# 创建目标图片目录
chemqa_img_dir = data_path / output_name
chemqa_img_dir.mkdir(exist_ok=True)
merged_data = []
global_offset = 0 # 全局索引偏移量
for pq_file in parquet_files:
part_name = pq_file.stem # 获取不带扩展名的文件名,如test-00000-of-00001
part_img_dir = data_path / part_name
# 验证图片文件夹存在
if not part_img_dir.exists():
raise FileNotFoundError(f"图片目录不存在: {part_img_dir}")
# 从Parquet获取行数
with pq.ParquetFile(pq_file) as pf:
num_rows = pf.metadata.num_rows
# 复制并重命名图片
print(f"合并 {part_name}{num_rows} 张图片...")
for idx in range(num_rows):
src = part_img_dir / f"{idx}.jpg" # 假设图片格式为jpg
dst = chemqa_img_dir / f"{global_offset + idx}.jpg"
shutil.copy2(src, dst)
# 从data_dir的父目录加载对应JSON文件
json_file = data_path.parent / f"{part_name}.json"
if not json_file.exists():
raise FileNotFoundError(f"JSON文件不存在: {json_file}")
with open(json_file, 'r', encoding='utf-8') as f:
part_data = json.load(f)
# 更新索引和路径
for item in part_data:
original_idx = item["index"]
item["index"] = global_offset + original_idx
item["media_paths"] = f"./data/{output_name}/{global_offset + original_idx}.jpg"
merged_data.extend(part_data)
global_offset += num_rows
# 保存合并后的JSON到data_dir父目录
output_json = data_path.parent / f"{output_name}.json"
with open(output_json, 'w', encoding='utf-8') as f:
json.dump(merged_data, f, indent=2, ensure_ascii=False)
print(f"\n合并完成!共处理 {len(parquet_files)} 个部分")
print(f"生成数据集: {output_name}")
print(f"- 图片总数: {global_offset} 张(位于 {chemqa_img_dir})")
print(f"- JSON条目: {len(merged_data)} 条(位于 {output_json})")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="合并问答数据集")
parser.add_argument('-i', dest="dataset_dir", required=True, help="Parquet文件所在的目录路径")
parser.add_argument('-ip', dest="data_dir", required=True, help="图片文件夹根目录路径(JSON在其父目录)")
parser.add_argument('-o', "--output_name", default="ChemQA", help="输出数据集名称(默认为ChemQA)")
args = parser.parse_args()
merge_chemqa_dataset(
dataset_dir=args.dataset_dir,
data_dir=args.data_dir,
output_name=args.output_name
)