|
|
|
|
|
"""
|
|
|
Download the TrashNet dataset and split its images into per-class train/val directories.
|
|
|
|
|
|
Usage:
|
|
|
python prepare_trashnet.py [--output-dir OUTPUT_DIR] [--test-size TEST_SIZE] [--seed SEED]
|
|
|
"""
|
|
|
|
|
|
import os
|
|
|
import sys
|
|
|
import argparse
|
|
|
import tempfile
|
|
|
import urllib.request
|
|
|
import zipfile
|
|
|
import random
|
|
|
import shutil
|
|
|
|
|
|
def download_zip(url, dest_path):
    """Fetch *url* and store the response body at *dest_path*.

    Any exception raised while announcing or performing the download is
    reported on stderr, and the process then terminates with exit status 1.
    The function returns ``None`` on success.
    """
    try:
        print(f"Downloading {url} to {dest_path}...")
        urllib.request.urlretrieve(url, dest_path)
    except Exception as exc:  # any failure here is fatal for the script
        reason = f"Error downloading dataset: {exc}"
        print(reason, file=sys.stderr)
        sys.exit(1)
|
|
|
|
|
|
def extract_zip(zip_path, extract_to):
    """Unpack every member of the archive at *zip_path* into *extract_to*.

    If opening or extracting the archive raises (e.g. the file is missing or
    not a valid zip), the error is printed to stderr and the process exits
    with status 1.
    """
    try:
        print(f"Extracting {zip_path} to {extract_to}...")
        with zipfile.ZipFile(zip_path, 'r') as archive:
            archive.extractall(extract_to)
    except Exception as exc:  # any failure here is fatal for the script
        reason = f"Error extracting zip file: {exc}"
        print(reason, file=sys.stderr)
        sys.exit(1)
|
|
|
|
|
|
def split_and_copy(source_dir, output_dir, test_size, seed):
    """Split each class folder of *source_dir* into train/val and copy the images.

    Every immediate sub-directory ``cls`` of *source_dir* is treated as a
    class. Its image files (``.jpg``, ``.jpeg`` or ``.png``, case-insensitive,
    excluding hidden files) are shuffled with a seeded RNG, and then copied
    (with metadata, via ``shutil.copy2``) to ``output_dir/train/cls`` and
    ``output_dir/val/cls``. Both directories are created even when a class
    has no images.

    Args:
        source_dir: Directory whose sub-directories are the class folders.
        output_dir: Destination root; ``train/`` and ``val/`` go inside it.
        test_size: Fraction in ``[0, 1]`` of each class placed in ``val``.
        seed: Seed for the shuffle; the same seed and the same input files
            give the same split on any machine.

    Raises:
        ValueError: If *test_size* lies outside ``[0, 1]``.
    """
    if not 0.0 <= test_size <= 1.0:
        raise ValueError(f"test_size must be between 0 and 1, got {test_size!r}")

    # A private RNG makes the split independent of, and harmless to, the
    # global ``random`` state. It produces the same sequence as random.seed().
    rng = random.Random(seed)

    # os.listdir order is filesystem-dependent, so sort both listings;
    # otherwise the seed would not make the split reproducible.
    classes = sorted(
        d for d in os.listdir(source_dir)
        if os.path.isdir(os.path.join(source_dir, d))
    )

    for cls in classes:
        class_dir = os.path.join(source_dir, cls)
        images = sorted(
            f for f in os.listdir(class_dir)
            # Hidden files (e.g. macOS "._name.jpg" metadata) share the image
            # extension but are not images, so they are excluded.
            if not f.startswith('.')
            and f.lower().endswith(('.jpg', '.jpeg', '.png'))
            and os.path.isfile(os.path.join(class_dir, f))
        )
        rng.shuffle(images)

        # Round before truncating so float error cannot drop an image from
        # train. Example: 10 * (1 - 0.9) == 0.9999999999999998.
        split_idx = int(round(len(images) * (1 - test_size), 9))
        train_imgs = images[:split_idx]
        val_imgs = images[split_idx:]

        for phase, files in (('train', train_imgs), ('val', val_imgs)):
            target_dir = os.path.join(output_dir, phase, cls)
            os.makedirs(target_dir, exist_ok=True)
            for fname in files:
                shutil.copy2(os.path.join(class_dir, fname),
                             os.path.join(target_dir, fname))
|
|
|
|
|
|
TRASHNET_ZIP_URL = 'https://github.com/garythung/trashnet/archive/refs/heads/master.zip'

# Class folder names that identify the directory holding the TrashNet images.
TRASHNET_CLASSES = frozenset(
    {'cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash'}
)


def _test_size(value):
    """argparse ``type``: parse *value* as a float in the closed range [0, 1]."""
    fraction = float(value)  # ValueError -> argparse reports "invalid value"
    if not 0.0 <= fraction <= 1.0:
        raise argparse.ArgumentTypeError(
            f"must be between 0 and 1, got {value!r}"
        )
    return fraction


def _parse_args(argv=None):
    """Parse command-line options from *argv* (defaults to ``sys.argv[1:]``)."""
    parser = argparse.ArgumentParser(
        description="Prepare TrashNet dataset (download → split)."
    )
    parser.add_argument(
        '--output-dir', '-o', default="data",
        help="Output directory for train/val folders (default: './data')."
    )
    parser.add_argument(
        '--test-size', '-t', type=_test_size, default=0.2,
        help="Fraction of data for validation (default: 0.2)."
    )
    parser.add_argument(
        '--seed', '-s', type=int, default=42,
        help="Random seed for shuffling (default: 42)."
    )
    return parser.parse_args(argv)


def _find_images_dir(root):
    """Return the first directory under *root* that contains every TrashNet class folder.

    The walk is top-down and visits sub-directories in sorted order. Hidden
    directories and ``__MACOSX`` metadata folders are skipped, because they
    can mirror the real class folders. Returns ``None`` if nothing matches.
    """
    for current, dirs, _files in os.walk(root):
        # Prune in place so os.walk neither matches nor descends into them.
        dirs[:] = sorted(
            d for d in dirs if d != '__MACOSX' and not d.startswith('.')
        )
        if TRASHNET_CLASSES.issubset(dirs):
            return current
    return None


def main():
    """Download TrashNet, unpack it, and write a seeded train/val split.

    All intermediate files live in a temporary directory. It is removed on
    every exit path, including exceptions and the ``sys.exit`` calls made
    here or by the download/extract helpers.
    """
    args = _parse_args()

    os.makedirs(args.output_dir, exist_ok=True)
    for phase in ('train', 'val'):
        phase_dir = os.path.join(args.output_dir, phase)
        # Copies are merged into existing folders, so leftovers from a run
        # with a different seed or test size would leak across splits.
        if os.path.isdir(phase_dir) and os.listdir(phase_dir):
            print(f"Warning: '{phase_dir}' is not empty; existing files "
                  "will be mixed into the new split.", file=sys.stderr)

    with tempfile.TemporaryDirectory(prefix='trashnet_') as temp_dir:
        zip_path = os.path.join(temp_dir, 'trashnet.zip')
        download_zip(TRASHNET_ZIP_URL, zip_path)
        extract_zip(zip_path, temp_dir)

        # Sorted so the choice below does not depend on listdir order.
        # GitHub branch archives normally unpack into one top-level folder.
        extracted_subdirs = sorted(
            d for d in os.listdir(temp_dir)
            if os.path.isdir(os.path.join(temp_dir, d))
        )
        if not extracted_subdirs:
            print("No directories found after initial extraction.", file=sys.stderr)
            sys.exit(1)
        extracted_root = os.path.join(temp_dir, extracted_subdirs[0])

        # The repository may ship its images as a second zip inside data/.
        data_dir = os.path.join(extracted_root, 'data')
        nested_zip = os.path.join(data_dir, 'dataset-resized.zip')
        if os.path.isfile(nested_zip):
            print(f"Found nested zip at {nested_zip}, extracting...")
            extract_zip(nested_zip, data_dir)
        else:
            print(f"No nested dataset-resized.zip found at {nested_zip}")

        images_dir = _find_images_dir(extracted_root)
        if images_dir is None:
            print(f"❌ Could not find the image folders under {extracted_root}.", file=sys.stderr)
            print("Here’s what was found:", file=sys.stderr)
            for root, dirs, _ in os.walk(extracted_root):
                print(f"- {root} → subdirs: {dirs}", file=sys.stderr)
            sys.exit(1)

        print(f"Using images from: {images_dir}")
        split_and_copy(images_dir, args.output_dir, args.test_size, args.seed)

    print(f"✅ Data prep complete. Train/val splits are in '{args.output_dir}'.")


if __name__ == '__main__':
    main()
|
|
|
|