#!/usr/bin/env python3
"""
Robustly prepare the TrashNet dataset by downloading it and splitting it into
train/val directories.

Usage:
    python prepare_data.py [--output-dir OUTPUT_DIR] [--test-size TEST_SIZE] [--seed SEED]
"""
import os
import sys
import argparse
import tempfile
import urllib.request
import zipfile
import random
import shutil


def download_zip(url, dest_path):
    """Download `url` to `dest_path`, exiting the program on failure."""
    try:
        print(f"Downloading {url} to {dest_path}...")
        # urlretrieve is urllib's legacy interface, but it is fine for a
        # one-off scripted download like this.
        urllib.request.urlretrieve(url, dest_path)
    except Exception as e:
        print(f"Error downloading dataset: {e}", file=sys.stderr)
        sys.exit(1)


def extract_zip(zip_path, extract_to):
    """Extract `zip_path` into `extract_to`, exiting the program on failure."""
    try:
        print(f"Extracting {zip_path} to {extract_to}...")
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(extract_to)
    except Exception as e:
        print(f"Error extracting zip file: {e}", file=sys.stderr)
        sys.exit(1)
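

# A small optional helper (an illustrative sketch, not wired into main()):
# zipfile.is_zipfile checks the file's magic bytes, which can catch a
# truncated or HTML-error-page download before attempting extraction.
def is_probably_zip(path):
    """Return True if `path` begins with a zip file header."""
    return zipfile.is_zipfile(path)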


def split_and_copy(source_dir, output_dir, test_size, seed):
    """Shuffle each class folder and copy its images into
    <output_dir>/train/<class> and <output_dir>/val/<class>."""
    random.seed(seed)
    classes = [d for d in os.listdir(source_dir)
               if os.path.isdir(os.path.join(source_dir, d))]
    for cls in classes:
        class_dir = os.path.join(source_dir, cls)
        images = [f for f in os.listdir(class_dir)
                  if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
        random.shuffle(images)
        # Splitting per class keeps the class balance of train and val similar.
        split_idx = int(len(images) * (1 - test_size))
        train_imgs = images[:split_idx]
        val_imgs = images[split_idx:]
        for phase, files in [('train', train_imgs), ('val', val_imgs)]:
            target_dir = os.path.join(output_dir, phase, cls)
            os.makedirs(target_dir, exist_ok=True)
            for fname in files:
                shutil.copy2(os.path.join(class_dir, fname),
                             os.path.join(target_dir, fname))
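

# Resulting layout under --output-dir (illustrative):
#   <output_dir>/
#     train/<class>/*.jpg|*.jpeg|*.png
#     val/<class>/*.jpg|*.jpeg|*.png
# A directory-per-class tree like this is what folder-based image loaders
# (e.g. torchvision's ImageFolder) expect.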


def main():
    parser = argparse.ArgumentParser(
        description="Prepare TrashNet dataset (download → split)."
    )
    parser.add_argument(
        '--output-dir', '-o', default="data",
        help="Output directory for train/val folders (default: './data')."
    )
    parser.add_argument(
        '--test-size', '-t', type=float, default=0.2,
        help="Fraction of data for validation (default: 0.2)."
    )
    parser.add_argument(
        '--seed', '-s', type=int, default=42,
        help="Random seed for shuffling (default: 42)."
    )
    args = parser.parse_args()
    # Ensure base output exists
    os.makedirs(args.output_dir, exist_ok=True)

    # 1) Download the GitHub repo ZIP
    temp_dir = tempfile.mkdtemp(prefix='trashnet_')
    zip_url = 'https://github.com/garythung/trashnet/archive/refs/heads/master.zip'
    zip_path = os.path.join(temp_dir, 'trashnet.zip')
    download_zip(zip_url, zip_path)

    # 2) Extract the repo ZIP
    extract_zip(zip_path, temp_dir)
    extracted_subdirs = [
        d for d in os.listdir(temp_dir)
        if os.path.isdir(os.path.join(temp_dir, d))
    ]
    if not extracted_subdirs:
        print("No directories found after initial extraction.", file=sys.stderr)
        shutil.rmtree(temp_dir)
        sys.exit(1)
    extracted_root = os.path.join(temp_dir, extracted_subdirs[0])
    # 3) Unzip the nested dataset-resized.zip if present, reusing extract_zip
    # so the nested archive gets the same error handling as the repo ZIP.
    nested_zip = os.path.join(extracted_root, 'data', 'dataset-resized.zip')
    if os.path.isfile(nested_zip):
        print(f"Found nested zip at {nested_zip}, extracting...")
        extract_zip(nested_zip, os.path.join(extracted_root, 'data'))
    else:
        print(f"No nested dataset-resized.zip found at {nested_zip}")
    # 4) Auto-locate the folder that contains the class dirs
    images_dir = None
    for root, dirs, _ in os.walk(extracted_root):
        # Look for the six standard TrashNet classes.
        if {'cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash'}.issubset(set(dirs)):
            images_dir = root
            break
    if not images_dir:
        print(f"❌ Could not find the image folders under {extracted_root}.", file=sys.stderr)
        print("Here's what was found:", file=sys.stderr)
        for root, dirs, _ in os.walk(extracted_root):
            print(f"- {root} → subdirs: {dirs}", file=sys.stderr)
        shutil.rmtree(temp_dir)
        sys.exit(1)
    print(f"Using images from: {images_dir}")
print(f"Using images from: {images_dir}")
# 5) Split into train/val
split_and_copy(images_dir, args.output_dir, args.test_size, args.seed)
print(f"✅ Data prep complete. Train/val splits are in '{args.output_dir}'.")
# Cleanup
shutil.rmtree(temp_dir)
if __name__ == '__main__':
main()
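

# Example of consuming the prepared splits (a minimal sketch, assuming
# torchvision is available; this script itself does not depend on it):
#
#     from torchvision import datasets, transforms
#
#     train_ds = datasets.ImageFolder('data/train', transform=transforms.ToTensor())
#     val_ds = datasets.ImageFolder('data/val', transform=transforms.ToTensor())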