# tools/utils/json/emb_ai_jtp.py
# Uploaded by Adinosaur via huggingface_hub ("Upload folder using huggingface_hub"),
# revision 1c980b1 (verified).
import argparse
import json
import base64
import sys
import time
import logging
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Tuple, List
from PIL import Image
import io
# Configure root logging once at import: INFO and above, timestamped,
# written to stdout rather than stderr.
logging.basicConfig(
    stream=sys.stdout,
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
def process_json_element(element: dict,
index: int,
output_dir: Path,
overwrite: bool,
output_format: str) -> Tuple[int, str]: # 新增格式参数
"""处理单个JSON数组元素(支持格式选择)"""
try:
# 参数校验并统一转为小写
output_format = output_format.lower() # 确保后续判断统一使用小写
if output_format not in ['jpg', 'png']:
raise ValueError(f"不支持的格式: {output_format}")
# 动态生成路径参数(文件扩展名保持小写)
file_ext = output_format
# 映射PIL所需的格式名称(JPEG非JPG)
img_format = 'JPEG' if output_format == 'jpg' else output_format.upper()
save_args = {'quality': 95} if output_format == 'jpg' else {'compress_level': 6}
output_path = output_dir / f"{index}.{file_ext}"
# 跳过已存在文件
if not overwrite and output_path.exists():
return (index, "skipped")
# 数据校验
if not isinstance(element, dict):
raise ValueError("数组元素不是字典类型")
if "image" not in element:
raise KeyError("缺少'image'字段")
# 图像解码
image_bytes = base64.b64decode(element["image"])
with Image.open(io.BytesIO(image_bytes)) as img:
# 公共处理:CMYK转换
if img.mode == 'CMYK':
img = img.convert('RGB')
# 格式专用处理
if output_format == 'jpg':
# 处理需要转换的透明模式
if img.mode == 'RGBA':
background = Image.new('RGB', img.size, (255, 255, 255))
background.paste(img, mask=img.split()[-1])
img = background
elif img.mode in ['P', 'PA']: # 调色板模式处理
img = img.convert('RGBA')
background = Image.new('RGB', img.size, (255, 255, 255))
background.paste(img, mask=img.split()[-1])
img = background
elif img.mode == 'LA': # 灰度+透明度
img = img.convert('L')
# 最终模式校验
if img.mode not in ['RGB', 'L']:
img = img.convert('RGB')
# 保存图像(使用PIL兼容的格式名)
img.save(output_path, img_format, **save_args)
return (index, "success")
except Exception as e:
return (index, f"error: {str(e)}")
def process_single_json(json_path: Path,
                        output_root: Path,
                        threads: int = 4,
                        overwrite: bool = False,
                        output_format: str = 'jpg') -> Tuple[str, int, int]:
    """Extract every image from one JSON file, decoding elements concurrently.

    Images go to ``output_root / <json stem>``. If any element fails, a
    ``process_errors.log`` (UTF-8, ordered by element index) is written there.

    Args:
        json_path: JSON file whose root must be an array.
        output_root: Parent directory for the per-file output directory.
        threads: Worker count for this file's thread pool.
        overwrite: Passed through to ``process_json_element``.
        output_format: ``"jpg"`` or ``"png"``, passed through.

    Returns:
        ``(file name, success count, error count)``. If the file cannot be
        read or is not an array, returns ``(file name, 0, 1)``.
    """
    start_time = time.time()
    file_stem = json_path.stem
    output_dir = output_root / file_stem
    output_dir.mkdir(parents=True, exist_ok=True)
    errors = []  # (element index, log line) pairs, sorted before writing
    success_count = 0
    skipped_count = 0
    try:
        # JSON text is UTF-8 (RFC 8259). Don't depend on the locale's default encoding.
        with open(json_path, "r", encoding="utf-8") as f:
            json_data = json.load(f)
        if not isinstance(json_data, list):
            raise ValueError("JSON根元素不是数组类型")
        with ThreadPoolExecutor(max_workers=threads) as executor:
            futures = [
                executor.submit(
                    process_json_element,
                    element,
                    idx,
                    output_dir,
                    overwrite,
                    output_format,
                )
                for idx, element in enumerate(json_data)
            ]
            for future in as_completed(futures):
                idx, status = future.result()
                if status == "success":
                    success_count += 1
                elif status == "skipped":
                    skipped_count += 1
                elif status.startswith("error"):
                    # Drop the "error:" tag and its trailing space.
                    detail = status.partition(":")[2].strip()
                    errors.append((idx, f"元素{idx}错误: {detail}"))
        process_time = time.time() - start_time
        logging.info(
            f"文件 {file_stem} 处理完成 | "
            f"成功: {success_count} | "
            f"跳过: {skipped_count} | "
            f"错误: {len(errors)} | "
            f"耗时: {process_time:.2f}s"
        )
        if errors:
            # as_completed yields in completion order, so sort by element index.
            errors.sort(key=lambda item: item[0])
            # Explicit UTF-8: the messages contain Chinese text, which a
            # non-UTF-8 locale could fail to encode.
            (output_dir / "process_errors.log").write_text(
                "\n".join(line for _, line in errors), encoding="utf-8"
            )
        return (json_path.name, success_count, len(errors))
    except Exception as e:
        logging.error(f"文件处理失败 {json_path.name}: {str(e)}")
        return (json_path.name, 0, 1)
def batch_process_jsons(input_dir: Path,
                        output_root: Path,
                        threads: int = 4,
                        overwrite: bool = False,
                        output_format: str = 'jpg'):
    """Process every ``*.json`` file directly inside *input_dir*.

    Files are handled concurrently with ``process_single_json``, and totals
    are logged at the end.

    Args:
        input_dir: Directory to scan (non-recursive). str or Path.
        output_root: Root output directory. str or Path.
        threads: Worker count for the file pool, and for each file's own pool.
        overwrite: Passed through to the per-file processing.
        output_format: ``"jpg"`` or ``"png"``, passed through.

    Raises:
        FileNotFoundError: If *input_dir* does not exist.
    """
    input_path = Path(input_dir)
    output_root = Path(output_root)
    if not input_path.exists():
        raise FileNotFoundError(f"输入目录不存在: {input_path}")
    # Sorted so the submission order is deterministic across platforms.
    json_files = sorted(input_path.glob("*.json"))
    if not json_files:
        logging.warning("未找到JSON文件")
        return
    total_stats = {"success": 0, "errors": 0}
    # NOTE: each worker here calls process_single_json, which opens its own
    # pool of `threads` workers. Roughly threads * threads element workers
    # may therefore run at once.
    with ThreadPoolExecutor(max_workers=threads) as executor:
        futures = {
            executor.submit(
                process_single_json,
                json_file,
                output_root,
                threads,
                overwrite,
                output_format,
            ): json_file
            for json_file in json_files
        }
        for future in as_completed(futures):
            try:
                _, success, errors = future.result()
            except Exception as e:
                total_stats["errors"] += 1
                # Name the file that failed, using the future -> file mapping.
                logging.error(f"处理异常 {futures[future].name}: {str(e)}")
            else:
                total_stats["success"] += success
                total_stats["errors"] += errors
    logging.info(f"\n{'='*40}")
    logging.info(f"处理完成文件总数: {len(json_files)}")
    logging.info(f"总成功图片数: {total_stats['success']}")
    logging.info(f"总错误数: {total_stats['errors']}")
if __name__ == "__main__":
    # CLI entry: parse arguments, run the batch, and exit with status 1 on any failure.
    parser = argparse.ArgumentParser(description="处理JSON文件中的Base64图片(支持格式选择)")
    parser.add_argument("-i", "--input", required=True, help="输入目录路径")
    parser.add_argument("-o", "--output", required=True, help="输出目录路径")
    parser.add_argument("--threads", type=int, default=4, help="并发线程数(默认4)")
    parser.add_argument("--overwrite", action="store_true", help="覆盖已存在的文件")
    parser.add_argument("--format", choices=['png', 'jpg'], default='jpg',
                        help="输出图片格式(png/jpg,默认jpg)")
    args = parser.parse_args()
    # Reject non-positive worker counts here, as a usage error, instead of
    # letting ThreadPoolExecutor fail later with an opaque ValueError.
    if args.threads < 1:
        parser.error("--threads 必须为正整数")
    try:
        start = time.time()
        batch_process_jsons(
            input_dir=args.input,
            output_root=args.output,
            threads=args.threads,
            overwrite=args.overwrite,
            output_format=args.format,
        )
        logging.info(f"\n总耗时: {time.time()-start:.2f}秒")
    except Exception as e:
        # logging.exception adds the traceback, which an unexpected crash needs.
        logging.exception(f"程序异常终止: {str(e)}")
        sys.exit(1)