Spaces:
Configuration error
Configuration error
| #!/usr/bin/env python3 | |
| # -*- coding: utf-8 -*- | |
| import argparse | |
| import os | |
| import random | |
| import shutil | |
| from pathlib import Path | |
# File suffixes (lower-case, including the leading dot) that count as images.
IMG_EXTS = [
    ".png",
    ".jpg",
    ".jpeg",
    ".bmp",
    ".webp",
]


def is_image(p: Path) -> bool:
    """Return True when *p* carries one of the IMG_EXTS suffixes.

    The comparison is case-insensitive (``photo.PNG`` qualifies); only the
    final suffix is inspected.
    """
    suffix = p.suffix.lower()
    return suffix in IMG_EXTS
def safe_mkdir(p: Path):
    """Ensure directory *p* exists, creating any missing parent directories.

    Calling it on a directory that already exists is a no-op.
    """
    p.mkdir(exist_ok=True, parents=True)
def copy_or_link(src: Path, dst: Path, mode: str):
    """Place ``src`` at ``dst`` by copying, symlinking or hard-linking it.

    Any existing entry at ``dst`` (a regular file or a symlink, including a
    dangling one) is removed first, so the function can be re-run over an
    existing output tree.

    Args:
        src: Existing source file.
        dst: Destination path; missing parent directories are created.
        mode: ``"copy"`` (``shutil.copy2``, keeps metadata), ``"symlink"``
            or ``"hardlink"``.

    Raises:
        ValueError: If ``mode`` is not supported. This is checked before
            anything is created or removed, so a bad mode never deletes an
            existing ``dst``.
    """
    if mode not in ("copy", "symlink", "hardlink"):
        raise ValueError(f"unknown mode: {mode}")
    safe_mkdir(dst.parent)
    # Path.exists() follows symlinks and reports False for a dangling link;
    # without the is_symlink() check such a leftover would make os.symlink /
    # os.link fail with FileExistsError.
    if dst.is_symlink() or dst.exists():
        dst.unlink()
    if mode == "copy":
        shutil.copy2(src, dst)
    elif mode == "symlink":
        # A relative symlink target is interpreted relative to the link's own
        # directory, not the CWD, so always link to the absolute source path.
        os.symlink(src.absolute(), dst)
    else:  # "hardlink"
        os.link(src, dst)
def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--src", default="/home/jeong/data_v3", help="source images dir (contains 1.png~69.png)")
    ap.add_argument("--labels", default="/home/jeong/data_v3/labels", help="source labels dir (contains 1.txt~69.txt)")
    ap.add_argument("--out", default="/home/jeong/jetank_ws/data_jenga_v3_yolo", help="output dataset root")
    ap.add_argument("--val_n", type=int, default=8, help="number of validation images (recommended 6~9 for 69 imgs)")
    ap.add_argument("--seed", type=int, default=42, help="random seed for split")
    ap.add_argument("--mode", choices=["copy", "symlink", "hardlink"], default="copy",
                    help="how to place files into output (copy safest; symlink fastest)")
    ap.add_argument("--class_names", default="jenga",
                    help="comma-separated class names for data.yaml (e.g., 'jenga' or 'jenga,jenga_stack')")
    return ap.parse_args()


def _collect_pairs(src_dir: Path, lab_dir: Path):
    """Match each image in ``src_dir`` to ``lab_dir/<stem>.txt``.

    Returns:
        ``(pairs, missing)``: ``pairs`` is a list of ``(image, label)`` paths
        in sorted image order; ``missing`` lists names of images that have
        no label file.

    Raises:
        RuntimeError: If ``src_dir`` contains no image files.
    """
    images = sorted(p for p in src_dir.iterdir() if p.is_file() and is_image(p))
    if not images:
        raise RuntimeError(f"no images found in: {src_dir}")
    pairs = []
    missing = []
    for img in images:
        lbl = lab_dir / f"{img.stem}.txt"
        if lbl.exists():
            pairs.append((img, lbl))
        else:
            missing.append(img.name)
    return pairs, missing


def _split_pairs(pairs, val_n: int, seed: int):
    """Shuffle ``pairs`` deterministically and return ``(train, val)``.

    ``val_n`` is clamped to ``[1, len(pairs) - 1]`` so both splits are
    non-empty (the caller guarantees at least two pairs).
    """
    # A private RNG seeded like random.seed(seed) yields the same shuffle
    # without mutating the process-wide random state.
    shuffled = list(pairs)
    random.Random(seed).shuffle(shuffled)
    val_n = max(1, min(val_n, len(shuffled) - 1))
    return shuffled[val_n:], shuffled[:val_n]


def _clear_split_dir(d: Path) -> int:
    """Remove files/symlinks directly inside ``d``; return how many were removed.

    Leftovers from an earlier run with a different seed or ``--val_n`` would
    otherwise leave the same image in both train and val. Sub-directories are
    not touched; unlinking a symlink or hardlink never affects its source.
    """
    removed = 0
    for entry in d.iterdir():
        if entry.is_symlink() or entry.is_file():
            entry.unlink()
            removed += 1
    return removed


def _write_data_yaml(out_dir: Path, class_names) -> Path:
    """Write an Ultralytics-style ``data.yaml`` into ``out_dir`` and return its path."""
    import json  # local import: only used to quote YAML scalars

    def q(text: str) -> str:
        # A JSON string is a valid YAML double-quoted scalar; quoting stops
        # names such as "no"/"on"/"null" or paths containing ": " / "#" from
        # being mis-parsed by a YAML loader.
        return json.dumps(text, ensure_ascii=False)

    lines = [
        # Absolute path: a relative `path` would not be resolved against the
        # CWD by Ultralytics.
        f"path: {q(str(out_dir.resolve()))}",
        "train: images/train",
        "val: images/val",
        "names:",
    ]
    lines += [f"  {i}: {q(name)}" for i, name in enumerate(class_names)]
    data_yaml = out_dir / "data.yaml"
    data_yaml.write_text("\n".join(lines) + "\n", encoding="utf-8")
    return data_yaml


def main():
    """Split matched image/label pairs into a YOLO train/val dataset layout.

    Images from ``--src`` are paired with same-stem ``.txt`` files from
    ``--labels`` (unpaired images are reported and skipped). The pairs are
    shuffled with ``--seed``. ``--val_n`` pairs go to ``<out>/images/val`` and
    ``<out>/labels/val``, the rest to the ``train`` counterparts, using
    ``--mode`` (copy/symlink/hardlink). Finally ``<out>/data.yaml`` is written.

    Raises:
        FileNotFoundError: If ``--src`` or ``--labels`` is not a directory.
        RuntimeError: If no images are found or fewer than two pairs match.
        ValueError: If an output split directory is the same as an input dir.
    """
    args = _parse_args()
    src_dir = Path(args.src)
    lab_dir = Path(args.labels)
    out_dir = Path(args.out)
    if not src_dir.is_dir():
        raise FileNotFoundError(f"src not found: {src_dir}")
    if not lab_dir.is_dir():
        raise FileNotFoundError(f"labels not found: {lab_dir}")

    pairs, missing_labels = _collect_pairs(src_dir, lab_dir)
    if missing_labels:
        print("[WARN] missing label files for images:")
        for n in missing_labels:
            print(" -", n)
        print("[WARN] those will be skipped.")
    if len(pairs) < 2:
        raise RuntimeError("not enough (image,label) pairs to split.")

    train_pairs, val_pairs = _split_pairs(pairs, args.val_n, args.seed)

    img_train = out_dir / "images" / "train"
    img_val = out_dir / "images" / "val"
    lab_train = out_dir / "labels" / "train"
    lab_val = out_dir / "labels" / "val"
    split_dirs = [img_train, img_val, lab_train, lab_val]

    # The split dirs are cleared below; refuse to run if one of them is an
    # input dir, since that would delete the source files.
    input_dirs = {src_dir.resolve(), lab_dir.resolve()}
    for d in split_dirs:
        if d.resolve() in input_dirs:
            raise ValueError(f"output dir must differ from --src/--labels: {d}")

    for d in split_dirs:
        safe_mkdir(d)
        removed = _clear_split_dir(d)
        if removed:
            print(f"[INFO] removed {removed} stale file(s) from {d}")

    for img, lbl in train_pairs:
        copy_or_link(img, img_train / img.name, args.mode)
        copy_or_link(lbl, lab_train / lbl.name, args.mode)
    for img, lbl in val_pairs:
        copy_or_link(img, img_val / img.name, args.mode)
        copy_or_link(lbl, lab_val / lbl.name, args.mode)

    class_names = [c.strip() for c in args.class_names.split(",") if c.strip()]
    if not class_names:
        class_names = ["jenga"]
    data_yaml = _write_data_yaml(out_dir, class_names)

    print("✅ Done.")
    print(f"- total pairs: {len(pairs)}")
    print(f"- train: {len(train_pairs)}")
    print(f"- val: {len(val_pairs)}")
    print(f"- out: {out_dir}")
    print(f"- data.yaml: {data_yaml}")
# Run the dataset split only when executed as a script, not when imported.
if __name__ == "__main__":
    main()