# palletizer/train/make_yolo_split.py
# Uploaded by: jhsjhs8566
# Commit message: "Upload 193 files"
# Commit: fbc2aa9 (verified)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
import json
import os
import random
import shutil
from pathlib import Path
# Lower-case file extensions (leading dot included) that count as images.
IMG_EXTS = [".png", ".jpg", ".jpeg", ".bmp", ".webp"]


def is_image(p: Path) -> bool:
    """Return True when the final suffix of *p* is a known image extension.

    The comparison is case-insensitive, so "IMG.PNG" is accepted.
    """
    ext = p.suffix.lower()
    return any(ext == known for known in IMG_EXTS)
def safe_mkdir(p: Path):
    """Create directory *p* and any missing parents; do nothing if it already exists."""
    os.makedirs(p, exist_ok=True)
def copy_or_link(src: Path, dst: Path, mode: str):
    """Place *src* at *dst* by copying, symlinking or hard-linking it.

    Any existing file or symlink at *dst* (including a dangling symlink) is
    replaced, and the parent directory of *dst* is created if needed.

    Args:
        src: source file to place.
        dst: destination path.
        mode: "copy" (shutil.copy2), "symlink" (os.symlink to the absolute
            source path) or "hardlink" (os.link).

    Raises:
        ValueError: if *mode* is not one of the above. This is checked before
            anything at *dst* is touched.
    """
    if mode not in ("copy", "symlink", "hardlink"):
        raise ValueError(f"unknown mode: {mode}")
    dst.parent.mkdir(parents=True, exist_ok=True)
    # Path.exists() follows symlinks and is False for a dangling link. Checking
    # only exists() would leave the stale link in place and make os.symlink /
    # os.link fail with FileExistsError, so check is_symlink() as well.
    if dst.is_symlink() or dst.exists():
        dst.unlink()
    if mode == "copy":
        shutil.copy2(src, dst)
    elif mode == "symlink":
        # A relative link target is interpreted relative to the link's own
        # directory, not the CWD, so link to the absolute source path.
        os.symlink(src.resolve(), dst)
    else:
        os.link(src, dst)
def _parse_args():
    """Define and parse this script's command-line options."""
    ap = argparse.ArgumentParser(
        description="Split an image dir + YOLO label dir into train/val and write data.yaml."
    )
    ap.add_argument("--src", default="/home/jeong/data_v3", help="source images dir (contains 1.png~69.png)")
    ap.add_argument("--labels", default="/home/jeong/data_v3/labels", help="source labels dir (contains 1.txt~69.txt)")
    ap.add_argument("--out", default="/home/jeong/jetank_ws/data_jenga_v3_yolo", help="output dataset root")
    ap.add_argument("--val_n", type=int, default=8, help="number of validation images (recommended 6~9 for 69 imgs)")
    ap.add_argument("--seed", type=int, default=42, help="random seed for split")
    ap.add_argument("--mode", choices=["copy", "symlink", "hardlink"], default="copy",
                    help="how to place files into output (copy safest; symlink fastest)")
    ap.add_argument("--class_names", default="jenga",
                    help="comma-separated class names for data.yaml (e.g., 'jenga' or 'jenga,jenga_stack')")
    ap.add_argument("--clean", action="store_true",
                    help="delete existing images/{train,val} and labels/{train,val} under --out before "
                         "writing, so files from a previous split cannot leak between train and val")
    return ap.parse_args()


def _collect_pairs(src_dir, lab_dir):
    """Match each image directly inside *src_dir* to ``lab_dir / f"{stem}.txt"``.

    Returns:
        (pairs, missing): ``pairs`` is a list of (image_path, label_path) sorted
        by image path; ``missing`` holds file names of images with no label file.

    Raises:
        RuntimeError: if *src_dir* contains no image files.
    """
    images = sorted(p for p in src_dir.iterdir() if p.is_file() and is_image(p))
    if not images:
        raise RuntimeError(f"no images found in: {src_dir}")
    pairs = []
    missing = []
    for img in images:
        lbl = lab_dir / f"{img.stem}.txt"
        if lbl.is_file():
            pairs.append((img, lbl))
        else:
            missing.append(img.name)
    return pairs, missing


def _write_data_yaml(out_dir, class_names):
    """Write an Ultralytics-style ``data.yaml`` into *out_dir* and return its path.

    The path and class names are emitted as JSON string literals. These are
    valid YAML double-quoted scalars, so values containing ':' or '#' do not
    corrupt the file.
    """
    data_yaml = out_dir / "data.yaml"
    lines = [
        f"path: {json.dumps(str(out_dir), ensure_ascii=False)}",
        "train: images/train",
        "val: images/val",
        "names:",
    ]
    lines += [f"  {i}: {json.dumps(name, ensure_ascii=False)}" for i, name in enumerate(class_names)]
    data_yaml.write_text("\n".join(lines) + "\n", encoding="utf-8")
    return data_yaml


def main():
    """Split an image/label folder into a YOLO train/val dataset and write data.yaml.

    Steps:
        1. Match images in --src (non-recursive) to ``--labels/<stem>.txt``.
           Images without a label are skipped with a warning.
        2. Shuffle the pairs with --seed.
        3. Put --val_n pairs into val and the rest into train. val_n is clamped
           so each split gets at least one pair.
        4. Place the files under ``--out/{images,labels}/{train,val}`` using
           --mode.

    Raises:
        FileNotFoundError: if --src or --labels is not an existing directory.
        RuntimeError: if no images are found or fewer than 2 pairs are usable.
    """
    args = _parse_args()
    src_dir = Path(args.src)
    lab_dir = Path(args.labels)
    # Absolute, so the `path:` written to data.yaml does not depend on the CWD
    # the training job is later started from.
    out_dir = Path(args.out).resolve()
    if not src_dir.is_dir():
        raise FileNotFoundError(f"src not found: {src_dir}")
    if not lab_dir.is_dir():
        raise FileNotFoundError(f"labels not found: {lab_dir}")

    pairs, missing_labels = _collect_pairs(src_dir, lab_dir)
    if missing_labels:
        print("[WARN] missing label files for images:")
        for n in missing_labels:
            print(" -", n)
        print("[WARN] those will be skipped.")
    if len(pairs) < 2:
        raise RuntimeError("not enough (image,label) pairs to split.")

    # A dedicated generator gives the same order as random.seed(seed) +
    # random.shuffle(), without reseeding the module-global RNG.
    random.Random(args.seed).shuffle(pairs)
    val_n = max(1, min(args.val_n, len(pairs) - 1))  # ensure at least 1 train and 1 val
    if val_n != args.val_n:
        print(f"[WARN] --val_n {args.val_n} clamped to {val_n} (have {len(pairs)} pairs).")
    val_pairs = pairs[:val_n]
    train_pairs = pairs[val_n:]

    img_train = out_dir / "images" / "train"
    img_val = out_dir / "images" / "val"
    lab_train = out_dir / "labels" / "train"
    lab_val = out_dir / "labels" / "val"
    split_dirs = [img_train, img_val, lab_train, lab_val]

    # copy_or_link only replaces same-named files. Anything else left over from
    # an earlier split (different seed/val_n) would remain and could place the
    # same image in both train and val.
    if args.clean:
        for d in split_dirs:
            if d.exists():
                shutil.rmtree(d)
    else:
        stale = [d for d in split_dirs if d.is_dir() and any(d.iterdir())]
        if stale:
            print("[WARN] output split dirs are not empty; files from earlier runs that are not "
                  "overwritten will remain and may leak between train/val (use --clean):")
            for d in stale:
                print(" -", d)
    for d in split_dirs:
        safe_mkdir(d)

    for split_pairs, img_out, lab_out in (
        (train_pairs, img_train, lab_train),
        (val_pairs, img_val, lab_val),
    ):
        for img, lbl in split_pairs:
            copy_or_link(img, img_out / img.name, args.mode)
            copy_or_link(lbl, lab_out / lbl.name, args.mode)

    class_names = [c.strip() for c in args.class_names.split(",") if c.strip()] or ["jenga"]
    data_yaml = _write_data_yaml(out_dir, class_names)

    print("✅ Done.")
    print(f"- total pairs: {len(pairs)}")
    print(f"- train: {len(train_pairs)}")
    print(f"- val: {len(val_pairs)}")
    print(f"- out: {out_dir}")
    print(f"- data.yaml: {data_yaml}")


if __name__ == "__main__":
    main()