Spaces:
Running
Running
| """ | |
| scripts/download_data.py | |
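| Build the 5-category waste dataset from local copies of TrashNet, TACO, RealWaste, operator feedback and local-boost images (nothing is downloaded; fetch the raw datasets from the sources below first). | |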
| Usage: | |
| python scripts/download_data.py --output_dir data/processed | |
| TrashNet source: https://github.com/garythung/trashnet | |
| TACO source: http://tacodataset.org | |
| """ | |
| import argparse | |
| import logging | |
| import os | |
| import random | |
| import shutil | |
| from pathlib import Path | |
| from PIL import Image, UnidentifiedImageError | |
# Console logging for the whole script: "<LEVEL> <message>", INFO and above.
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")
log = logging.getLogger(__name__)

# The categories every source dataset is remapped onto.
TARGET_CLASSES = ["plastic", "paper", "organic", "metal", "glass"]

# Each *_MAP translates a lower-cased source folder name to a target class.
# A value of None is treated the same as a missing key: the folder is skipped.
TRASHNET_MAP = dict(
    plastic="plastic",
    paper="paper",
    cardboard="paper",
    metal="metal",
    glass="glass",
    trash=None,
)

# Only organic-like folders are taken from TACO.
TACO_MAP = dict.fromkeys(
    ["food", "food_waste", "vegetable", "fruit", "organic"],
    "organic",
)

REALWASTE_MAP = dict(
    cardboard="paper",
    food_organics="organic",
    glass="glass",
    metal="metal",
    paper="paper",
    plastic="plastic",
    vegetation="organic",
)

# Feedback and local-boost folders are already named after the target classes,
# so their maps are the identity over TARGET_CLASSES (independent dict objects).
FEEDBACK_MAP = {name: name for name in TARGET_CLASSES}
LOCAL_BOOST_MAP = dict(FEEDBACK_MAP)

# Split names and their fractions; the "train" and "val" fractions are applied
# per class and "test" receives whatever is left over.
SPLITS = dict(train=0.70, val=0.15, test=0.15)

# File suffixes (lower-case, with the dot) accepted as input images.
INPUT_EXTS = {".jpg", ".jpeg", ".png", ".webp"}
def phash(path: str, size: int = 16) -> str:
    """Return an average-hash bit string used for duplicate detection.

    The image is converted to greyscale and shrunk to ``size`` x ``size``
    pixels; each pixel then contributes "1" if it is brighter than the mean
    of all pixels and "0" otherwise.

    Args:
        path: Image file to hash.
        size: Edge length of the thumbnail; the result has ``size * size``
            characters.

    Returns:
        The bit string, or "" when the image cannot be opened or processed.
    """
    try:
        # Image.open is lazy and can keep the underlying file open; the
        # context manager guarantees it is closed even when decoding fails.
        # convert()/resize() return new in-memory images, so the thumbnail
        # stays usable after the block exits.
        with Image.open(path) as img:
            thumb = img.convert("L").resize((size, size), Image.LANCZOS)
        pixels = list(thumb.getdata())
        avg = sum(pixels) / len(pixels)
        return "".join("1" if p > avg else "0" for p in pixels)
    except Exception:
        # Deliberately best-effort: decoding can fail in many ways, and an
        # empty hash tells the caller to skip this file rather than abort.
        return ""
def verify_image(path: str) -> bool:
    """Return True if ``path`` opens as an image and passes ``Image.verify()``.

    Args:
        path: Image file to check.

    Returns:
        True when the file opens and verifies cleanly, False on any failure
        (missing file, unknown format, corrupted data, ...).
    """
    try:
        with Image.open(path) as img:
            img.verify()
    except Exception:
        # UnidentifiedImageError is itself an Exception subclass, so listing
        # it alongside Exception was redundant. Any failure means "not usable".
        return False
    return True
def collect_images(source_dir: str, class_map: dict) -> dict:
    """Walk ``source_dir`` and return ``{target_class: [path, ...]}``.

    Each immediate sub-directory of ``source_dir`` is looked up (by its
    lower-cased name) in ``class_map``. Folders that are missing from the map,
    or mapped to None, are skipped. Every file below a mapped folder whose
    suffix is in INPUT_EXTS is added to that folder's target class.

    Args:
        source_dir: Root directory containing one sub-directory per source
            class.
        class_map: Mapping of lower-cased folder name to target class (or
            None to skip the folder).

    Returns:
        A dict with one key per entry of TARGET_CLASSES. Paths are strings
        built from ``source_dir``, so they are absolute only if ``source_dir``
        is.
    """
    collected = {c: [] for c in TARGET_CLASSES}
    # iterdir()/rglob() yield entries in filesystem order, which differs
    # between machines. Sorting keeps the later seeded shuffle and the
    # "first one wins" de-duplication reproducible.
    for folder in sorted(Path(source_dir).iterdir()):
        if not folder.is_dir():
            continue
        target = class_map.get(folder.name.lower())
        if target is None:
            continue
        collected[target].extend(
            str(file_path)
            for file_path in sorted(folder.rglob("*"))
            # is_file() keeps out directories that happen to be named like
            # an image (e.g. "scans.jpg/").
            if file_path.is_file() and file_path.suffix.lower() in INPUT_EXTS
        )
    return collected
def deduplicate(paths: list[str], threshold: int = 8) -> list[str]:
    """Drop near-duplicate images, keeping the first occurrence of each.

    Two images count as duplicates when the Hamming distance between their
    ``phash`` bit strings is at most ``threshold``. Images for which ``phash``
    returns "" (unreadable) are dropped as well.

    Args:
        paths: Candidate image paths, in priority order.
        threshold: Maximum number of differing hash bits for two images to be
            treated as the same picture.

    Returns:
        The retained paths, in their original relative order.
    """
    seen_hashes: list[int] = []
    unique: list[str] = []
    for path in paths:
        image_hash = phash(path)
        if not image_hash:
            continue
        # Parse the bit string once. XOR + popcount then gives the same
        # Hamming distance as comparing the strings character by character,
        # without a Python-level loop over every bit for every known hash.
        value = int(image_hash, 2)
        if any(bin(value ^ known).count("1") <= threshold for known in seen_hashes):
            continue
        seen_hashes.append(value)
        unique.append(path)
    return unique
def reset_output_dir(output_dir: str) -> None:
    """Delete the per-split sub-directories left behind by an earlier run.

    Only the directories named after the keys of SPLITS are removed; anything
    else inside ``output_dir`` is left alone. Nothing happens when
    ``output_dir`` itself does not exist.
    """
    base = Path(output_dir)
    if base.exists():
        for split_name in SPLITS:
            stale_dir = base / split_name
            if stale_dir.exists():
                shutil.rmtree(stale_dir)
def split_and_copy(images: dict, output_dir: str) -> dict:
    """Shuffle each class, split it into train/val/test and copy the files.

    Files are copied to ``<output_dir>/<split>/<class>/`` and renamed to
    ``<class>_<split>_<index><ext>`` with a zero-padded index and lower-cased
    extension. The train and val sizes are the truncated SPLITS fractions of
    the class size; the test split receives the remainder.

    Args:
        images: Mapping of class name to a list of source image paths. The
            lists are not modified.
        output_dir: Root directory that receives the split folders.

    Returns:
        ``{class: {"total": n, "train": n, "val": n, "test": n}}``.
    """
    # A private generator seeded with 42 yields the same permutations as
    # seeding the module-level generator, without clobbering global random
    # state for the rest of the process.
    rng = random.Random(42)
    out_root = Path(output_dir)
    stats = {}
    for cls, paths in images.items():
        # Shuffle a copy so the caller's lists keep their original order.
        shuffled = list(paths)
        rng.shuffle(shuffled)
        total = len(shuffled)
        n_train = int(total * SPLITS["train"])
        n_val = int(total * SPLITS["val"])
        split_paths = {
            "train": shuffled[:n_train],
            "val": shuffled[n_train:n_train + n_val],
            "test": shuffled[n_train + n_val:],
        }
        for split, items in split_paths.items():
            dest_dir = out_root / split / cls
            dest_dir.mkdir(parents=True, exist_ok=True)
            for index, src in enumerate(items):
                ext = Path(src).suffix.lower()
                dest = dest_dir / f"{cls}_{split}_{index:05d}{ext}"
                shutil.copy2(src, dest)
        stats[cls] = {
            "total": total,
            **{split: len(items) for split, items in split_paths.items()},
        }
    return stats
def _build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser (options and defaults are unchanged)."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--trashnet_dir",
        default="data/raw/trashnet",
        help="Path to the unzipped TrashNet dataset",
    )
    parser.add_argument(
        "--taco_dir",
        default="data/raw/taco",
        help="Path to the TACO image folders",
    )
    parser.add_argument(
        "--realwaste_dir",
        default="data/raw/realwaste",
        help="Path to the organized RealWaste image folders",
    )
    parser.add_argument(
        "--feedback_dir",
        default="data/feedback_labeled",
        help="Path to operator-reviewed feedback images organized by class",
    )
    parser.add_argument(
        "--extra_dir",
        default="data/local_boost",
        help="Path to extra local training images organized by class",
    )
    parser.add_argument("--output_dir", default="data/processed")
    return parser


def _collect_from_sources(sources) -> dict:
    """Merge images from every source directory that exists.

    ``sources`` is an iterable of ``(directory, class_map, info_message,
    missing_message)`` tuples processed in order. A missing directory only
    produces a warning (``missing_message`` formatted with the directory).
    """
    all_images = {c: [] for c in TARGET_CLASSES}
    for directory, class_map, info_message, missing_message in sources:
        if not os.path.isdir(directory):
            log.warning(missing_message, directory)
            continue
        log.info(info_message)
        found = collect_images(directory, class_map)
        for cls in TARGET_CLASSES:
            all_images[cls].extend(found[cls])
    return all_images


def _print_summary(stats: dict, output_dir: str) -> None:
    """Print the per-class split table followed by the overall totals."""
    print("\nDataset summary")
    print(f"{'Class':<12} {'Total':>7} {'Train':>7} {'Val':>7} {'Test':>7}")
    print("-" * 44)
    for cls, summary in stats.items():
        print(
            f"{cls:<12} {summary['total']:>7} {summary['train']:>7} "
            f"{summary['val']:>7} {summary['test']:>7}"
        )
    grand_total = sum(summary["total"] for summary in stats.values())
    print(f"\nTotal images: {grand_total}")
    print(f"Output dir : {os.path.abspath(output_dir)}")


def main():
    """Build the processed dataset from the raw source directories.

    Steps: collect images from every available source, drop files that fail
    verification, remove near-duplicates per class, wipe the previous split
    folders, copy the new train/val/test split and print a summary.

    Raises:
        SystemExit: with status 1 when no usable image is found, so that an
            existing processed dataset is not deleted by a misconfigured run.
    """
    args = _build_parser().parse_args()

    # (directory, folder->class map, progress message, "missing" warning)
    sources = (
        (
            args.trashnet_dir,
            TRASHNET_MAP,
            "Collecting from TrashNet...",
            "TrashNet dir not found: %s\nDownload from https://github.com/garythung/trashnet and unzip.",
        ),
        (
            args.taco_dir,
            TACO_MAP,
            "Collecting from TACO...",
            "TACO dir not found: %s. Skipping organic supplement.",
        ),
        (
            args.realwaste_dir,
            REALWASTE_MAP,
            "Collecting from RealWaste...",
            "RealWaste dir not found: %s. Skipping RealWaste supplement.",
        ),
        (
            args.feedback_dir,
            FEEDBACK_MAP,
            "Collecting from operator feedback...",
            "Feedback dir not found: %s. Skipping operator feedback supplement.",
        ),
        (
            args.extra_dir,
            LOCAL_BOOST_MAP,
            "Collecting from local boost dataset...",
            "Local boost dir not found: %s. Skipping local boost supplement.",
        ),
    )
    all_images = _collect_from_sources(sources)

    log.info("Verifying images...")
    for cls in TARGET_CLASSES:
        before = len(all_images[cls])
        all_images[cls] = [path for path in all_images[cls] if verify_image(path)]
        removed = before - len(all_images[cls])
        if removed:
            log.warning(" %s: removed %s corrupted files", cls, removed)

    log.info("Deduplicating...")
    for cls in TARGET_CLASSES:
        before = len(all_images[cls])
        all_images[cls] = deduplicate(all_images[cls])
        log.info(" %s: %s -> %s after dedup", cls, before, len(all_images[cls]))

    # Resetting the output directory is destructive; never do it when there
    # is nothing to write in its place (e.g. all source paths were wrong).
    if not any(all_images.values()):
        log.error(
            "No usable images found in any source directory; leaving %s untouched.",
            args.output_dir,
        )
        raise SystemExit(1)

    log.info("Splitting and copying...")
    reset_output_dir(args.output_dir)
    stats = split_and_copy(all_images, args.output_dir)

    _print_summary(stats, args.output_dir)


if __name__ == "__main__":
    main()