""" scripts/download_data.py Download TrashNet + TACO datasets and remap to 5 waste categories. Usage: python scripts/download_data.py --output_dir data/processed TrashNet source: https://github.com/garythung/trashnet TACO source: http://tacodataset.org """ import argparse import logging import os import random import shutil from pathlib import Path from PIL import Image, UnidentifiedImageError logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s") log = logging.getLogger(__name__) TRASHNET_MAP = { "plastic": "plastic", "paper": "paper", "cardboard": "paper", "metal": "metal", "glass": "glass", "trash": None, } TACO_MAP = { "food": "organic", "food_waste": "organic", "vegetable": "organic", "fruit": "organic", "organic": "organic", } REALWASTE_MAP = { "cardboard": "paper", "food_organics": "organic", "glass": "glass", "metal": "metal", "paper": "paper", "plastic": "plastic", "vegetation": "organic", } FEEDBACK_MAP = { "plastic": "plastic", "paper": "paper", "organic": "organic", "metal": "metal", "glass": "glass", } LOCAL_BOOST_MAP = FEEDBACK_MAP.copy() TARGET_CLASSES = ["plastic", "paper", "organic", "metal", "glass"] SPLITS = {"train": 0.70, "val": 0.15, "test": 0.15} INPUT_EXTS = {".jpg", ".jpeg", ".png", ".webp"} def phash(path: str, size: int = 16) -> str: """Perceptual hash for duplicate detection.""" try: img = Image.open(path).convert("L").resize((size, size), Image.LANCZOS) pixels = list(img.getdata()) avg = sum(pixels) / len(pixels) return "".join("1" if p > avg else "0" for p in pixels) except Exception: return "" def verify_image(path: str) -> bool: try: with Image.open(path) as img: img.verify() return True except (UnidentifiedImageError, Exception): return False def collect_images(source_dir: str, class_map: dict) -> dict: """Walk source_dir and return {target_class: [abs_path, ...]}.""" collected = {c: [] for c in TARGET_CLASSES} for folder in Path(source_dir).iterdir(): if not folder.is_dir(): continue target = class_map.get(folder.name.lower()) if target is None: continue for file_path in folder.rglob("*"): if file_path.suffix.lower() in INPUT_EXTS: collected[target].append(str(file_path)) return collected def deduplicate(paths: list[str], threshold: int = 8) -> list[str]: seen_hashes = [] unique = [] for path in paths: image_hash = phash(path) if not image_hash: continue is_duplicate = any( sum(a != b for a, b in zip(image_hash, known_hash)) <= threshold for known_hash in seen_hashes ) if not is_duplicate: seen_hashes.append(image_hash) unique.append(path) return unique def reset_output_dir(output_dir: str) -> None: root = Path(output_dir) if not root.exists(): return for split in SPLITS: split_dir = root / split if split_dir.exists(): shutil.rmtree(split_dir) def split_and_copy(images: dict, output_dir: str) -> dict: random.seed(42) stats = {} for cls, paths in images.items(): random.shuffle(paths) total = len(paths) n_train = int(total * SPLITS["train"]) n_val = int(total * SPLITS["val"]) split_paths = { "train": paths[:n_train], "val": paths[n_train:n_train + n_val], "test": paths[n_train + n_val:], } for split, items in split_paths.items(): dest_dir = Path(output_dir) / split / cls dest_dir.mkdir(parents=True, exist_ok=True) for index, src in enumerate(items): ext = Path(src).suffix.lower() dest = dest_dir / f"{cls}_{split}_{index:05d}{ext}" shutil.copy2(src, dest) stats[cls] = {"total": total, **{split: len(items) for split, items in split_paths.items()}} return stats def main(): parser = argparse.ArgumentParser() parser.add_argument( "--trashnet_dir", default="data/raw/trashnet", help="Path to the unzipped TrashNet dataset", ) parser.add_argument( "--taco_dir", default="data/raw/taco", help="Path to the TACO image folders", ) parser.add_argument( "--realwaste_dir", default="data/raw/realwaste", help="Path to the organized RealWaste image folders", ) parser.add_argument( "--feedback_dir", default="data/feedback_labeled", help="Path to operator-reviewed feedback images organized by class", ) parser.add_argument( "--extra_dir", default="data/local_boost", help="Path to extra local training images organized by class", ) parser.add_argument("--output_dir", default="data/processed") args = parser.parse_args() all_images = {c: [] for c in TARGET_CLASSES} if os.path.isdir(args.trashnet_dir): log.info("Collecting from TrashNet...") trashnet_images = collect_images(args.trashnet_dir, TRASHNET_MAP) for cls in TARGET_CLASSES: all_images[cls].extend(trashnet_images[cls]) else: log.warning( "TrashNet dir not found: %s\nDownload from https://github.com/garythung/trashnet and unzip.", args.trashnet_dir, ) if os.path.isdir(args.taco_dir): log.info("Collecting from TACO...") taco_images = collect_images(args.taco_dir, TACO_MAP) for cls in TARGET_CLASSES: all_images[cls].extend(taco_images[cls]) else: log.warning("TACO dir not found: %s. Skipping organic supplement.", args.taco_dir) if os.path.isdir(args.realwaste_dir): log.info("Collecting from RealWaste...") realwaste_images = collect_images(args.realwaste_dir, REALWASTE_MAP) for cls in TARGET_CLASSES: all_images[cls].extend(realwaste_images[cls]) else: log.warning("RealWaste dir not found: %s. Skipping RealWaste supplement.", args.realwaste_dir) if os.path.isdir(args.feedback_dir): log.info("Collecting from operator feedback...") feedback_images = collect_images(args.feedback_dir, FEEDBACK_MAP) for cls in TARGET_CLASSES: all_images[cls].extend(feedback_images[cls]) else: log.warning("Feedback dir not found: %s. Skipping operator feedback supplement.", args.feedback_dir) if os.path.isdir(args.extra_dir): log.info("Collecting from local boost dataset...") local_images = collect_images(args.extra_dir, LOCAL_BOOST_MAP) for cls in TARGET_CLASSES: all_images[cls].extend(local_images[cls]) else: log.warning("Local boost dir not found: %s. Skipping local boost supplement.", args.extra_dir) log.info("Verifying images...") for cls in TARGET_CLASSES: before = len(all_images[cls]) all_images[cls] = [path for path in all_images[cls] if verify_image(path)] removed = before - len(all_images[cls]) if removed: log.warning(" %s: removed %s corrupted files", cls, removed) log.info("Deduplicating...") for cls in TARGET_CLASSES: before = len(all_images[cls]) all_images[cls] = deduplicate(all_images[cls]) log.info(" %s: %s -> %s after dedup", cls, before, len(all_images[cls])) log.info("Splitting and copying...") reset_output_dir(args.output_dir) stats = split_and_copy(all_images, args.output_dir) print("\nDataset summary") print(f"{'Class':<12} {'Total':>7} {'Train':>7} {'Val':>7} {'Test':>7}") print("-" * 44) for cls, summary in stats.items(): print( f"{cls:<12} {summary['total']:>7} {summary['train']:>7} " f"{summary['val']:>7} {summary['test']:>7}" ) grand_total = sum(summary["total"] for summary in stats.values()) print(f"\nTotal images: {grand_total}") print(f"Output dir : {os.path.abspath(args.output_dir)}") if __name__ == "__main__": main()