# Page-scrape header (not part of the script): Spaces: Running | File size: 4,949 Bytes | a8aea21
import os
import shutil
import random
import logging
from pathlib import Path
# Logging setup: emit INFO and above, each record prefixed with an
# HH:MM:SS timestamp and its level name.
logging.basicConfig(
    datefmt="%H:%M:%S",
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)

# Module-level logger shared by every function in this script.
logger = logging.getLogger(__name__)
# Sampling and split configuration.
TARGET_PER_CATEGORY = 1000  # upper bound on images taken from one category
SPLIT_RATIO = (0.8, 0.1, 0.1)  # fractions for (train, val, test)

# Dataset layout: every location is a child of DATA_ROOT.
DATA_ROOT = Path("data")
PROCESSED_DIR = DATA_ROOT.joinpath("processed")
TRAIN_DIR, VAL_DIR, TEST_DIR = (
    DATA_ROOT.joinpath(split) for split in ("train", "val", "test")
)
def get_image_files(directory, extensions=None):
    """Recursively collect the image files under ``directory``.

    Args:
        directory: Root folder to search; a ``str`` or ``Path``.
        extensions: Optional iterable of file suffixes (leading dot
            included, matched case-insensitively) that count as images.
            Defaults to ``.jpg``, ``.jpeg``, ``.png``, ``.webp`` and
            ``.bmp``.

    Returns:
        A sorted list of ``Path`` objects for every regular file whose
        suffix matches. Empty when ``directory`` is not an existing
        directory.
    """
    if extensions is None:
        extensions = {'.jpg', '.jpeg', '.png', '.webp', '.bmp'}
    else:
        extensions = {ext.lower() for ext in extensions}

    root = Path(directory)
    if not root.is_dir():
        return []

    # The suffix test runs first so that is_file() (a filesystem call) is
    # only made for names that already look like images; it still filters
    # out directories that happen to be named like an image.
    return sorted(
        path
        for path in root.rglob("*")
        if path.suffix.lower() in extensions and path.is_file()
    )
def clear_directory(path):
    """Remove whatever currently exists at ``path``; do nothing if absent.

    A real directory is deleted together with its contents. A regular
    file or a symlink (including a dangling one) found at ``path`` is
    unlinked instead, because ``shutil.rmtree`` raises for both; a
    symlink's target is left untouched.

    Args:
        path: Location to clear; a ``str`` or ``Path``.
    """
    path = Path(path)
    # is_symlink() is checked first: exists()/is_dir() follow links, so a
    # dangling link would look absent and a link to a directory would be
    # handed to rmtree, which refuses symlinks.
    if path.is_symlink() or path.is_file():
        logger.warning("Removing existing file or symlink: %s", path)
        path.unlink()
    elif path.is_dir():
        logger.warning("Deleting existing directory: %s", path)
        shutil.rmtree(path)
def _find_categories(processed_dir):
    """Return ``(relative_dir, image_paths)`` for each folder directly holding images."""
    image_suffixes = {'.jpg', '.jpeg', '.png', '.webp', '.bmp'}
    categories = []
    for root, _dirs, files in os.walk(processed_dir):
        current_path = Path(root)
        # Only files directly inside this folder are considered; os.walk
        # visits sub-folders separately, so each becomes its own category.
        local_images = [
            current_path / name
            for name in files
            if Path(name).suffix.lower() in image_suffixes
        ]
        if local_images:
            categories.append((current_path.relative_to(processed_dir), local_images))
    return categories


def _copy_with_caption(img_path, dest_dir):
    """Copy one image and, if present, its same-stem ``.txt`` file; log failures."""
    try:
        shutil.copy2(img_path, dest_dir / img_path.name)
        # Try to copy the caption text file if it exists.
        txt_path = img_path.with_suffix(".txt")
        if txt_path.exists():
            shutil.copy2(txt_path, dest_dir / txt_path.name)
    except OSError as e:
        # Best effort: one unreadable file must not abort the whole resplit.
        logger.error(f"Failed to copy {img_path.name}: {e}")


def main():
    """Rebuild the train/val/test folders from the images under PROCESSED_DIR.

    Steps:
      1. Find every folder under PROCESSED_DIR that directly contains
         image files; each such folder is one category.
      2. Delete and recreate TRAIN_DIR, VAL_DIR and TEST_DIR.
      3. For each category, shuffle its images, keep at most
         TARGET_PER_CATEGORY of them, split them according to
         SPLIT_RATIO (the rounding remainder goes to the test split) and
         copy each image, plus a same-stem ``.txt`` file when one exists,
         into ``<split>/<category path>``.
      4. Log the number of images found in each split folder.

    The shuffle uses the ``random`` module without seeding it here, so
    the selection differs between runs. Returns ``None``; if no category
    is found an error is logged and the existing splits are left alone.
    """
    # NOTE(review): the leading glyphs in several log messages ("π", "β",
    # ...) look like mis-decoded emoji; they are kept as found here and
    # should be restored from the upstream file.
    logger.info("π Starting Dataset Resplit (v2)")
    logger.info(f"π― Target: {TARGET_PER_CATEGORY} images/category | Split: {SPLIT_RATIO}")

    # 1. Discover the categories first, so that a missing or empty
    #    processed folder does not wipe the splits that already exist.
    categories = _find_categories(PROCESSED_DIR)
    if not categories:
        logger.error("β No categories found in data/processed!")
        return

    logger.info(f"π Found {len(categories)} categories to process.")

    # 2. Reset the split folders.
    for split_dir in (TRAIN_DIR, VAL_DIR, TEST_DIR):
        clear_directory(split_dir)
        split_dir.mkdir(parents=True, exist_ok=True)

    # 3. Sample, split and copy each category.
    for rel_path, images in categories:
        category_name = rel_path.as_posix()
        logger.info(f"\nπΉ Processing: {category_name}")

        # Shuffle and select.
        random.shuffle(images)
        selected_images = images[:TARGET_PER_CATEGORY]
        count = len(selected_images)

        if count < TARGET_PER_CATEGORY:
            logger.warning(f" β οΈ Only found {count} images (Target: {TARGET_PER_CATEGORY})")
        else:
            # This literal was split by a stray line break in the source;
            # it is rejoined here and reports the configured target
            # rather than a hard-coded 1000.
            logger.info(f" ✅ Selected {TARGET_PER_CATEGORY} images from {len(images)} available.")

        # Train and val sizes are floored; the slice for test takes
        # whatever remains so the three parts always add up to `count`.
        n_train = int(count * SPLIT_RATIO[0])
        n_val = int(count * SPLIT_RATIO[1])

        train_set = selected_images[:n_train]
        val_set = selected_images[n_train:n_train + n_val]
        test_set = selected_images[n_train + n_val:]

        logger.info(f" Splitting: Train={len(train_set)}, Val={len(val_set)}, Test={len(test_set)}")

        # Copy files.
        for dataset, dest_root in (
            (train_set, TRAIN_DIR),
            (val_set, VAL_DIR),
            (test_set, TEST_DIR),
        ):
            if not dataset:
                continue

            dest_category_dir = dest_root / rel_path
            dest_category_dir.mkdir(parents=True, exist_ok=True)

            for img_path in dataset:
                _copy_with_caption(img_path, dest_category_dir)

    logger.info("\nπ Resplit Complete.")

    # 4. Verification stats: count the images that actually landed on disk.
    logger.info("π Final Counts:")
    for d, name in [(TRAIN_DIR, "TRAIN"), (VAL_DIR, "VAL"), (TEST_DIR, "TEST")]:
        img_count = len(get_image_files(d))
        logger.info(f" {name}: {img_count} images")
# Run the resplit only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()