#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Split a flat folder of images + YOLO label files into an Ultralytics dataset.

Every image directly inside ``--src`` is paired with ``<stem>.txt`` in
``--labels``; the pairs are shuffled with a fixed seed, split into train/val,
placed under ``<out>/images/{train,val}`` and ``<out>/labels/{train,val}``
(by copy, symlink or hardlink), and ``<out>/data.yaml`` is written.
"""
import argparse
import json
import os
import random
import shutil
from pathlib import Path
from typing import List, Optional, Sequence, Tuple

IMG_EXTS = [".png", ".jpg", ".jpeg", ".bmp", ".webp"]

# Ways copy_or_link can place a file into the output tree.
PLACE_MODES = ("copy", "symlink", "hardlink")

# Output split names; each gets an images/<split> and labels/<split> directory.
SPLITS = ("train", "val")

# An (image file, label file) pair.
Pair = Tuple[Path, Path]


def is_image(p: Path) -> bool:
    """Return True if *p* has one of the IMG_EXTS suffixes (case-insensitive)."""
    return p.suffix.lower() in IMG_EXTS


def safe_mkdir(p: Path) -> None:
    """Create directory *p* and any missing parents; no-op if it already exists."""
    p.mkdir(parents=True, exist_ok=True)


def _remove_existing(path: Path) -> None:
    """Delete whatever file or symlink occupies *path*, including a dangling symlink."""
    # Path.exists() follows symlinks and is False for a dangling link, so the
    # is_symlink() check is required: otherwise a leftover broken link makes
    # os.symlink raise FileExistsError and shutil.copy2 write *through* the
    # link to its old target.
    if path.is_symlink() or path.exists():
        path.unlink()


def copy_or_link(src: Path, dst: Path, mode: str) -> None:
    """Place *src* at *dst* by copying, symlinking or hardlinking it.

    Any existing file or link at *dst* is replaced, and *dst*'s parent
    directories are created as needed.

    Args:
        src: Existing source file.
        dst: Destination path.
        mode: One of ``"copy"``, ``"symlink"``, ``"hardlink"``.

    Raises:
        ValueError: If *mode* is unknown (checked before *dst* is touched).
        OSError: Propagated from the filesystem (e.g. hardlinking across devices).
    """
    if mode not in PLACE_MODES:
        raise ValueError(f"unknown mode: {mode}")
    safe_mkdir(dst.parent)
    _remove_existing(dst)
    if mode == "copy":
        shutil.copy2(src, dst)
    elif mode == "symlink":
        # A relative symlink target is interpreted relative to the link's own
        # directory, so a relative --src would yield broken links; anchor the
        # target to the current working directory instead.
        os.symlink(os.path.abspath(src), dst)
    else:
        os.link(src, dst)


def collect_pairs(src_dir: Path, lab_dir: Path) -> Tuple[List[Pair], List[str]]:
    """Pair every image directly inside *src_dir* with ``lab_dir/<stem>.txt``.

    Returns:
        ``(pairs, missing)``: image/label pairs in sorted image-path order, and
        the file names of images that have no label file (those are skipped).

    Raises:
        RuntimeError: If *src_dir* contains no image files.
    """
    images = sorted(p for p in src_dir.iterdir() if p.is_file() and is_image(p))
    if not images:
        raise RuntimeError(f"no images found in: {src_dir}")

    pairs: List[Pair] = []
    missing: List[str] = []
    for img in images:
        lbl = lab_dir / f"{img.stem}.txt"
        if lbl.is_file():
            pairs.append((img, lbl))
        else:
            missing.append(img.name)
    return pairs, missing


def split_pairs(pairs: Sequence[Pair], val_n: int, seed: int) -> Tuple[List[Pair], List[Pair]]:
    """Shuffle *pairs* reproducibly and split them into ``(train, val)``.

    *val_n* is clamped to ``[1, len(pairs) - 1]`` so both splits are non-empty.
    A private ``random.Random(seed)`` is used: it produces the same order as
    seeding the module-level generator, without mutating global random state.
    The input sequence is not modified.

    Raises:
        RuntimeError: If fewer than two pairs are given.
    """
    if len(pairs) < 2:
        raise RuntimeError("not enough (image,label) pairs to split.")
    shuffled = list(pairs)
    random.Random(seed).shuffle(shuffled)
    val_n = max(1, min(val_n, len(shuffled) - 1))  # ensure at least 1 train
    return shuffled[val_n:], shuffled[:val_n]


def _place_splits(out_dir: Path, train_pairs: Sequence[Pair],
                  val_pairs: Sequence[Pair], mode: str) -> None:
    """Write both splits under *out_dir*, first purging stale cross-split copies."""
    plan = (("train", "val", train_pairs), ("val", "train", val_pairs))
    # Remove each file of this run from the split it does NOT belong to before
    # placing anything. Without this, re-running with another --seed/--val_n
    # leaves the previous run's files behind and the same image sits in both
    # train and val. Purging up front (not per file) keeps same-named labels
    # that legitimately land in both splits during this run.
    for _split, other, members in plan:
        for img, lbl in members:
            _remove_existing(out_dir / "images" / other / img.name)
            _remove_existing(out_dir / "labels" / other / lbl.name)
    for split, _other, members in plan:
        for img, lbl in members:
            copy_or_link(img, out_dir / "images" / split / img.name, mode)
            copy_or_link(lbl, out_dir / "labels" / split / lbl.name, mode)


def parse_class_names(spec: str) -> List[str]:
    """Split a comma-separated class list, dropping blanks; fall back to ``["jenga"]``."""
    names = [c.strip() for c in spec.split(",") if c.strip()]
    return names or ["jenga"]


def _yaml_quote(value: str) -> str:
    """Render *value* as a double-quoted YAML scalar."""
    # A JSON string literal only uses escapes that YAML double-quoted scalars
    # also understand, so this safely handles ':', '#', quotes, leading '-'.
    return json.dumps(value, ensure_ascii=False)


def render_data_yaml(out_dir: Path, class_names: Sequence[str]) -> str:
    """Return Ultralytics-style ``data.yaml`` text for the dataset at *out_dir*.

    ``path`` and every class name are emitted as quoted scalars so values with
    YAML-special characters stay valid; *out_dir* is written as given.
    """
    lines = [
        f"path: {_yaml_quote(str(out_dir))}",
        "train: images/train",
        "val: images/val",
        "names:",
    ]
    lines += [f"  {i}: {_yaml_quote(name)}" for i, name in enumerate(class_names)]
    return "\n".join(lines) + "\n"


def main(argv: Optional[Sequence[str]] = None) -> None:
    """CLI entry point: build the train/val split and write ``data.yaml``.

    Args:
        argv: Arguments to parse; ``None`` (default) means ``sys.argv[1:]``.

    Raises:
        FileNotFoundError: If the source or labels directory does not exist.
        RuntimeError: If there are no images or fewer than two labelled pairs.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--src", default="/home/jeong/data_v3",
                    help="source images dir (contains 1.png~69.png)")
    ap.add_argument("--labels", default="/home/jeong/data_v3/labels",
                    help="source labels dir (contains 1.txt~69.txt)")
    ap.add_argument("--out", default="/home/jeong/jetank_ws/data_jenga_v3_yolo",
                    help="output dataset root")
    ap.add_argument("--val_n", type=int, default=8,
                    help="number of validation images (recommended 6~9 for 69 imgs)")
    ap.add_argument("--seed", type=int, default=42, help="random seed for split")
    ap.add_argument("--mode", choices=PLACE_MODES, default="copy",
                    help="how to place files into output (copy safest; symlink fastest)")
    ap.add_argument("--class_names", default="jenga",
                    help="comma-separated class names for data.yaml (e.g., 'jenga' or 'jenga,jenga_stack')")
    args = ap.parse_args(argv)

    src_dir = Path(args.src)
    lab_dir = Path(args.labels)
    out_dir = Path(args.out)

    if not src_dir.exists():
        raise FileNotFoundError(f"src not found: {src_dir}")
    if not lab_dir.exists():
        raise FileNotFoundError(f"labels not found: {lab_dir}")

    pairs, missing_labels = collect_pairs(src_dir, lab_dir)
    if missing_labels:
        print("[WARN] missing label files for images:")
        for n in missing_labels:
            print(" -", n)
        print("[WARN] those will be skipped.")

    train_pairs, val_pairs = split_pairs(pairs, args.val_n, args.seed)

    for split in SPLITS:
        safe_mkdir(out_dir / "images" / split)
        safe_mkdir(out_dir / "labels" / split)
    _place_splits(out_dir, train_pairs, val_pairs, args.mode)

    # Write an absolute dataset root: a relative --out written verbatim would
    # make data.yaml depend on the directory training is launched from.
    data_yaml = out_dir / "data.yaml"
    yaml_txt = render_data_yaml(out_dir.resolve(), parse_class_names(args.class_names))
    data_yaml.write_text(yaml_txt, encoding="utf-8")

    print("✅ Done.")
    print(f"- total pairs: {len(pairs)}")
    print(f"- train: {len(train_pairs)}")
    print(f"- val: {len(val_pairs)}")
    print(f"- out: {out_dir}")
    print(f"- data.yaml: {data_yaml}")


if __name__ == "__main__":
    main()