File size: 5,146 Bytes
7803bdf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
#!/usr/bin/env python
# coding: utf-8

"""
在 `data_root/` 下已经有 `train/` 和 `validation/` 两个文件夹时:
分别在这两个文件夹内生成对应的 `metadata.jsonl`,不复制任何图片。

`metadata.jsonl` 每行格式:
  {"file_name": "subdir/000026831.jpg", "caption": "..."}

其中 `file_name` 是相对当前 split 目录(train/ 或 validation/)的路径。
"""

import argparse
import json
import os
from concurrent.futures import ThreadPoolExecutor
from itertools import islice
from pathlib import Path
from typing import Optional, Tuple

from tqdm import tqdm


def parse_args() -> argparse.Namespace:
    """Parse command-line options for per-split metadata.jsonl generation.

    Returns:
        argparse.Namespace with attributes ``data_root``, ``jsonl_name``,
        ``use_txt_caption``, ``num_workers`` and ``max_images``.
    """
    parser = argparse.ArgumentParser(description="Generate per-split metadata.jsonl for imagefolder (no copy)")
    parser.add_argument(
        "--data_root",
        type=str,
        default="/gemini/space/hsd/project/dataset/cc3m-wds",
        help="数据根目录(必须包含 train/ 和 validation/)",
    )
    parser.add_argument(
        "--jsonl_name",
        type=str,
        default="metadata.jsonl",
        help="每个 split 下生成的 jsonl 文件名(默认 metadata.jsonl)",
    )
    parser.add_argument(
        "--use_txt_caption",
        # BUG FIX: the original used action="store_true" together with
        # default=True, so the option was always True and could never be
        # switched off, making the json-first branch downstream unreachable.
        # BooleanOptionalAction keeps "--use_txt_caption" valid and adds
        # "--no-use_txt_caption" to disable it (backward compatible, 3.9+).
        action=argparse.BooleanOptionalAction,
        default=True,
        help="优先读取同名 .txt 作为 caption(默认开启),否则回落到 .json",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=32,
        help="线程数(I/O 密集型建议 8~64 之间按机器调整)",
    )
    parser.add_argument(
        "--max_images",
        type=int,
        default=None,
        help="每个 split 最多处理多少张图片(None 表示全部,调试可用)",
    )
    return parser.parse_args()


def read_caption_from_txt(txt_path: Path) -> Optional[str]:
    """Return the stripped contents of *txt_path*, or None.

    None is returned when the file does not exist, cannot be read, or
    contains only whitespace.
    """
    if not txt_path.exists():
        return None
    try:
        text = txt_path.read_text(encoding="utf-8").strip()
    except Exception:
        # Best-effort read: unreadable sidecar files are treated as absent.
        return None
    return text if text else None


def read_caption_from_json(json_path: Path) -> Optional[str]:
    """Extract a caption string from a JSON sidecar file, or return None.

    Tries the keys "caption", "text", "description" in order and returns
    the first non-empty string value, stripped. Missing files, parse
    errors, non-dict payloads, and empty values all yield None.
    """
    if not json_path.exists():
        return None
    try:
        payload = json.loads(json_path.read_text(encoding="utf-8"))
    except Exception:
        # Best-effort read: malformed sidecar files are treated as absent.
        return None
    if not isinstance(payload, dict):
        return None
    for field in ("caption", "text", "description"):
        value = payload.get(field)
        if isinstance(value, str) and value.strip():
            return value.strip()
    return None


def main() -> None:
    """Write a metadata.jsonl inside each of train/ and validation/.

    For every .jpg/.jpeg/.png under a split directory, reads a caption
    from a same-named .txt or .json sidecar file (order controlled by
    --use_txt_caption) and writes one JSON line per image:
    {"file_name": <path relative to the split dir>, "caption": <text>}.
    Images with no usable caption are silently skipped. No image files
    are copied or moved.
    """
    args = parse_args()

    data_root = Path(args.data_root).resolve()
    if not data_root.exists():
        raise FileNotFoundError(f"数据根目录不存在:{data_root}")

    # Both split directories must already exist; fail fast before any work.
    splits = [("train", data_root / "train"), ("validation", data_root / "validation")]
    for split_name, split_dir in splits:
        if not split_dir.exists():
            raise FileNotFoundError(f"缺少目录:{split_dir}(需要 train/ 和 validation/)")

    def iter_images(split_dir: Path):
        # Lazily yield every image file under split_dir (recursive walk),
        # matching extensions case-insensitively.
        for root, _dirs, files in os.walk(split_dir):
            for name in files:
                if name.lower().endswith((".jpg", ".jpeg", ".png")):
                    yield Path(root) / name

    def process_one(img_path: Path, split_dir: Path) -> Optional[Tuple[str, str]]:
        # Resolve the caption for one image. Returns (relative_path, caption)
        # or None when no non-empty caption could be read from either sidecar.
        txt_path = img_path.with_suffix(".txt")
        json_path = img_path.with_suffix(".json")

        caption = None
        if args.use_txt_caption:
            # Prefer the .txt sidecar, fall back to .json.
            caption = read_caption_from_txt(txt_path)
            if caption is None:
                caption = read_caption_from_json(json_path)
        else:
            # Prefer the .json sidecar, fall back to .txt.
            caption = read_caption_from_json(json_path)
            if caption is None:
                caption = read_caption_from_txt(txt_path)

        if caption is None:
            return None

        # file_name is relative to the split dir, normalized to POSIX
        # separators as the imagefolder format expects.
        rel = img_path.relative_to(split_dir)
        return str(rel).replace(os.sep, "/"), caption

    for split_name, split_dir in splits:
        jsonl_path = split_dir / args.jsonl_name

        img_iter = iter_images(split_dir)
        if args.max_images is not None:
            img_iter = islice(img_iter, args.max_images)

        # tqdm only needs an iterable; we deliberately do not collect the
        # paths into a list, to keep memory flat on very large splits.
        # The bar therefore shows processed count only (total is unknown).
        def _task_iter():
            for p in img_iter:
                yield p

        written = 0
        with jsonl_path.open("w", encoding="utf-8") as f, ThreadPoolExecutor(max_workers=args.num_workers) as ex:
            # executor.map preserves input order; tqdm tracks progress.
            for result in tqdm(
                ex.map(lambda p: process_one(p, split_dir), _task_iter()),
                desc=f"[{split_name}] Processing",
            ):
                if result is None:
                    continue
                file_name, caption = result
                f.write(json.dumps({"file_name": file_name, "caption": caption}, ensure_ascii=False) + "\n")
                written += 1

        print(f"{split_name}: 写入 {written} 条 -> {jsonl_path}")


if __name__ == "__main__":
    main()

# nohup python cc3m_render.py > cc3m_render.log 2>&1 &