File size: 4,174 Bytes
fbc2aa9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import json
import os
import random
import shutil
from pathlib import Path

# Image file suffixes accepted by is_image (matched against a lower-cased suffix).
IMG_EXTS = [
    ".png",
    ".jpg",
    ".jpeg",
    ".bmp",
    ".webp",
]

def is_image(p: Path) -> bool:
    """Report whether ``p`` ends in one of IMG_EXTS, ignoring letter case."""
    ext = p.suffix.lower()
    return ext in IMG_EXTS

def safe_mkdir(p: Path):
    """Ensure directory ``p`` exists, creating missing parents; an existing directory is fine."""
    os.makedirs(p, exist_ok=True)

def copy_or_link(src: Path, dst: Path, mode: str):
    """Place ``src`` at ``dst`` by copying, symlinking or hard-linking it.

    ``dst.parent`` is created if needed. Any existing file or symlink at
    ``dst`` (including a dangling symlink) is replaced.

    Args:
        src: existing file to place.
        dst: destination path.
        mode: ``"copy"`` (shutil.copy2), ``"symlink"`` or ``"hardlink"``.

    Raises:
        ValueError: for an unknown ``mode``, or when ``dst`` is the very same
            directory entry as ``src`` (replacing it would delete the source).
            Both checks run before anything on disk is modified.
    """
    # Validate first so a bad mode never deletes an existing dst.
    if mode not in ("copy", "symlink", "hardlink"):
        raise ValueError(f"unknown mode: {mode}")
    # A non-symlink dst whose canonical path equals src's is the source file
    # itself; unlinking it below would destroy the data we are about to place.
    if dst.exists() and not dst.is_symlink() and dst.resolve() == src.resolve():
        raise ValueError(f"dst is the source file itself: {dst}")
    safe_mkdir(dst.parent)
    # Path.exists() follows symlinks and is False for a dangling link; without
    # the is_symlink() check such a link survives and os.symlink/os.link then
    # fail with FileExistsError.
    if dst.exists() or dst.is_symlink():
        dst.unlink()
    if mode == "copy":
        shutil.copy2(src, dst)
    elif mode == "symlink":
        # A relative symlink target is interpreted relative to dst's directory,
        # not the current working directory, so link to an absolute path.
        os.symlink(src.absolute(), dst)
    else:  # "hardlink"
        os.link(src, dst)

def _parse_args():
    """Define the command-line options and parse them from sys.argv."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--src", default="/home/jeong/data_v3", help="source images dir (contains 1.png~69.png)")
    ap.add_argument("--labels", default="/home/jeong/data_v3/labels", help="source labels dir (contains 1.txt~69.txt)")
    ap.add_argument("--out", default="/home/jeong/jetank_ws/data_jenga_v3_yolo", help="output dataset root")
    ap.add_argument("--val_n", type=int, default=8, help="number of validation images (recommended 6~9 for 69 imgs)")
    ap.add_argument("--seed", type=int, default=42, help="random seed for split")
    ap.add_argument("--mode", choices=["copy", "symlink", "hardlink"], default="copy",
                    help="how to place files into output (copy safest; symlink fastest)")
    ap.add_argument("--class_names", default="jenga",
                    help="comma-separated class names for data.yaml (e.g., 'jenga' or 'jenga,jenga_stack')")
    return ap.parse_args()


def _remove_stale(stale: Path, source: Path):
    """Delete a leftover output entry at ``stale`` without ever deleting ``source``.

    A symlink is always removed (only the link itself goes away). A regular
    file is removed unless its canonical path is that of ``source``.
    """
    if stale.is_symlink():
        stale.unlink()
    elif stale.exists() and stale.resolve() != source.resolve():
        stale.unlink()


def _purge_other_split(pairs, other_img_dir: Path, other_lab_dir: Path):
    """Remove same-named copies of ``pairs`` from the opposite split's directories.

    A previous run with a different --seed/--val_n may have placed these files
    in the other split; leaving them there would put the same image (and its
    label) in both train and val.
    """
    for img, lbl in pairs:
        _remove_stale(other_img_dir / img.name, img)
        _remove_stale(other_lab_dir / lbl.name, lbl)


def _yaml_quote(s: str) -> str:
    """Return ``s`` as a double-quoted YAML scalar.

    json.dumps emits a double-quoted string whose escapes YAML also accepts.
    Quoting keeps names such as ``no``/``on``/``null``/``1`` from being parsed
    as bool/null/int.
    """
    return json.dumps(s, ensure_ascii=False)


def _write_data_yaml(out_dir: Path, class_names) -> Path:
    """Write ``<out_dir>/data.yaml`` (Ultralytics format) and return its path."""
    data_yaml = out_dir / "data.yaml"
    # Absolute dataset root: a relative path would be re-resolved at training
    # time against a different base than the one used here.
    lines = [
        f"path: {_yaml_quote(str(out_dir.resolve()))}",
        "train: images/train",
        "val: images/val",
        "names:",
    ]
    lines += [f"  {i}: {_yaml_quote(name)}" for i, name in enumerate(class_names)]
    data_yaml.write_text("\n".join(lines) + "\n", encoding="utf-8")
    return data_yaml


def main():
    """Split matched (image, label) pairs into a YOLO train/val layout and write data.yaml.

    Images in --src are paired with ``<stem>.txt`` files in --labels (unmatched
    images are reported and skipped). Pairs are shuffled with --seed.
    ``val_n`` of them (clamped so both splits are non-empty) go to
    ``<out>/{images,labels}/val`` and the rest to ``<out>/{images,labels}/train``,
    placed with --mode. Finally ``<out>/data.yaml`` is written.

    Raises:
        FileNotFoundError: if --src or --labels does not exist.
        RuntimeError: if no images are found or fewer than 2 pairs match.
    """
    args = _parse_args()

    src_dir = Path(args.src)
    lab_dir = Path(args.labels)
    out_dir = Path(args.out)

    if not src_dir.exists():
        raise FileNotFoundError(f"src not found: {src_dir}")
    if not lab_dir.exists():
        raise FileNotFoundError(f"labels not found: {lab_dir}")

    # collect images (sorted so the seeded shuffle below is reproducible)
    images = sorted(p for p in src_dir.iterdir() if p.is_file() and is_image(p))
    if not images:
        raise RuntimeError(f"no images found in: {src_dir}")

    # match label by stem
    pairs = []
    missing_labels = []
    for img in images:
        lbl = lab_dir / f"{img.stem}.txt"
        if lbl.exists():
            pairs.append((img, lbl))
        else:
            missing_labels.append(img.name)

    if missing_labels:
        print("[WARN] missing label files for images:")
        for n in missing_labels:
            print("  -", n)
        print("[WARN] those will be skipped.")

    if len(pairs) < 2:
        raise RuntimeError("not enough (image,label) pairs to split.")

    # split: a private RNG gives the same permutation as seeding the global
    # `random` module, without mutating global state.
    rng = random.Random(args.seed)
    rng.shuffle(pairs)

    val_n = max(1, min(args.val_n, len(pairs) - 1))  # keep at least 1 val and 1 train pair
    val_pairs = pairs[:val_n]
    train_pairs = pairs[val_n:]

    # output dirs
    img_train = out_dir / "images" / "train"
    img_val   = out_dir / "images" / "val"
    lab_train = out_dir / "labels" / "train"
    lab_val   = out_dir / "labels" / "val"
    for d in (img_train, img_val, lab_train, lab_val):
        safe_mkdir(d)

    # Clear cross-split leftovers from earlier runs BEFORE placing anything, so
    # files shared by both splits (e.g. one label for two same-stem images)
    # are recreated below rather than deleted after placement.
    _purge_other_split(train_pairs, img_val, lab_val)
    _purge_other_split(val_pairs, img_train, lab_train)

    # copy/link
    for img, lbl in train_pairs:
        copy_or_link(img, img_train / img.name, args.mode)
        copy_or_link(lbl, lab_train / lbl.name, args.mode)

    for img, lbl in val_pairs:
        copy_or_link(img, img_val / img.name, args.mode)
        copy_or_link(lbl, lab_val / lbl.name, args.mode)

    # write data.yaml (Ultralytics format); fall back to a single "jenga" class
    class_names = [c.strip() for c in args.class_names.split(",") if c.strip()] or ["jenga"]
    data_yaml = _write_data_yaml(out_dir, class_names)

    print("✅ Done.")
    print(f"- total pairs: {len(pairs)}")
    print(f"- train: {len(train_pairs)}")
    print(f"- val: {len(val_pairs)}")
    print(f"- out: {out_dir}")
    print(f"- data.yaml: {data_yaml}")

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()