#!/usr/bin/env python
# coding: utf-8
# coding: utf-8
"""
在 `data_root/` 下已经有 `train/` 和 `validation/` 两个文件夹时:
分别在这两个文件夹内生成对应的 `metadata.jsonl`,不复制任何图片。
`metadata.jsonl` 每行格式:
{"file_name": "subdir/000026831.jpg", "caption": "..."}
其中 `file_name` 是相对当前 split 目录(train/ 或 validation/)的路径。
"""
import argparse
import json
import os
from concurrent.futures import ThreadPoolExecutor
from itertools import islice
from pathlib import Path
from typing import Optional, Tuple
from tqdm import tqdm
def parse_args() -> argparse.Namespace:
    """Parse command-line options for the metadata generator.

    Returns:
        argparse.Namespace with data_root, jsonl_name, use_txt_caption,
        num_workers and max_images attributes.
    """
    parser = argparse.ArgumentParser(description="Generate per-split metadata.jsonl for imagefolder (no copy)")
    parser.add_argument(
        "--data_root",
        type=str,
        default="/gemini/space/hsd/project/dataset/cc3m-wds",
        help="数据根目录(必须包含 train/ 和 validation/)",
    )
    parser.add_argument(
        "--jsonl_name",
        type=str,
        default="metadata.jsonl",
        help="每个 split 下生成的 jsonl 文件名(默认 metadata.jsonl)",
    )
    # BUG FIX: `action="store_true"` with `default=True` made this flag a no-op
    # (it could never be disabled). Keep the original flag and default for
    # backward compatibility, and add an opt-out flag sharing the same dest.
    parser.add_argument(
        "--use_txt_caption",
        dest="use_txt_caption",
        action="store_true",
        default=True,
        help="优先读取同名 .txt 作为 caption(默认开启),否则回落到 .json",
    )
    parser.add_argument(
        "--no_use_txt_caption",
        dest="use_txt_caption",
        action="store_false",
        help="Prefer the .json sidecar for captions and fall back to .txt",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=32,
        help="线程数(I/O 密集型建议 8~64 之间按机器调整)",
    )
    parser.add_argument(
        "--max_images",
        type=int,
        default=None,
        help="每个 split 最多处理多少张图片(None 表示全部,调试可用)",
    )
    return parser.parse_args()
def read_caption_from_txt(txt_path: Path) -> Optional[str]:
    """Return the stripped contents of *txt_path*, or None.

    None is returned when the file does not exist, cannot be read,
    or contains only whitespace.
    """
    if not txt_path.exists():
        return None
    try:
        text = txt_path.read_text(encoding="utf-8").strip()
    except Exception:
        # Best-effort: unreadable sidecar files are treated as "no caption".
        return None
    return text if text else None
def read_caption_from_json(json_path: Path) -> Optional[str]:
    """Return the first non-empty caption-like string field from a JSON sidecar.

    Checks "caption", "text" and "description" in that order. Returns None
    when the file is missing, unparseable, malformed, or has no usable field.
    """
    if not json_path.exists():
        return None
    try:
        payload = json.loads(json_path.read_text(encoding="utf-8"))
        # Keep the field scan inside the try so non-dict payloads
        # (lists, scalars) fall through to None instead of raising.
        for field in ("caption", "text", "description"):
            value = payload.get(field)
            if isinstance(value, str):
                stripped = value.strip()
                if stripped:
                    return stripped
    except Exception:
        return None
    return None
def main() -> None:
    """Generate a metadata.jsonl inside each of train/ and validation/.

    No images are copied: for every image found under a split directory,
    a caption is resolved from its .txt/.json sidecar and one JSON line
    {"file_name": <path relative to the split>, "caption": ...} is written.

    Raises:
        FileNotFoundError: if data_root or either split directory is missing.
    """
    args = parse_args()
    data_root = Path(args.data_root).resolve()
    if not data_root.exists():
        raise FileNotFoundError(f"数据根目录不存在:{data_root}")

    splits = [("train", data_root / "train"), ("validation", data_root / "validation")]
    for split_name, split_dir in splits:
        if not split_dir.exists():
            raise FileNotFoundError(f"缺少目录:{split_dir}(需要 train/ 和 validation/)")

    def iter_images(split_dir: Path):
        # Lazily walk the split directory and yield image paths
        # (avoids materializing the full file list in memory).
        for root, _dirs, files in os.walk(split_dir):
            for name in files:
                if name.lower().endswith((".jpg", ".jpeg", ".png")):
                    yield Path(root) / name

    def process_one(img_path: Path, split_dir: Path) -> Optional[Tuple[str, str]]:
        # Resolve a caption for one image from its sidecar files.
        # Both readers return None or a non-empty string, so `or` is an
        # exact encoding of the original "try first, fall back" logic.
        txt_path = img_path.with_suffix(".txt")
        json_path = img_path.with_suffix(".json")
        if args.use_txt_caption:
            caption = read_caption_from_txt(txt_path) or read_caption_from_json(json_path)
        else:
            caption = read_caption_from_json(json_path) or read_caption_from_txt(txt_path)
        if caption is None:
            return None
        rel = img_path.relative_to(split_dir)
        # Normalize to forward slashes so the jsonl is portable across OSes.
        return str(rel).replace(os.sep, "/"), caption

    for split_name, split_dir in splits:
        jsonl_path = split_dir / args.jsonl_name
        img_iter = iter_images(split_dir)
        if args.max_images is not None:
            img_iter = islice(img_iter, args.max_images)

        written = 0
        with jsonl_path.open("w", encoding="utf-8") as f, ThreadPoolExecutor(max_workers=args.num_workers) as ex:
            # executor.map preserves input order; tqdm shows a running count
            # (total is unknown because the iterator is not materialized).
            # `d=split_dir` binds the loop variable eagerly to avoid the
            # classic late-binding-closure pitfall.
            for result in tqdm(
                ex.map(lambda p, d=split_dir: process_one(p, d), img_iter),
                desc=f"[{split_name}] Processing",
            ):
                if result is None:
                    continue
                file_name, caption = result
                f.write(json.dumps({"file_name": file_name, "caption": caption}, ensure_ascii=False) + "\n")
                written += 1
        print(f"{split_name}: 写入 {written} 条 -> {jsonl_path}")
if __name__ == "__main__":
main()
# nohup python cc3m_render.py > cc3m_render.log 2>&1 &