# CoolWasteAI — scripts/download_data.py
# Author: Celvin — commit 12d831f
# "Prepare deployable AI API for competition and free hosting"
"""
scripts/download_data.py
Download TrashNet + TACO datasets and remap to 5 waste categories.
Usage:
python scripts/download_data.py --output_dir data/processed
TrashNet source: https://github.com/garythung/trashnet
TACO source: http://tacodataset.org
"""
import argparse
import logging
import os
import random
import shutil
from pathlib import Path
from PIL import Image, UnidentifiedImageError
# Plain "LEVEL message" console logging; the script is meant to be run from a shell.
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")
log = logging.getLogger(__name__)

# Source-folder-name -> target-class maps. Folder names are matched lowercased
# (see collect_images); a value of None means "discard images in that folder".
TRASHNET_MAP = {
    "plastic": "plastic",
    "paper": "paper",
    "cardboard": "paper",  # cardboard is folded into the paper class
    "metal": "metal",
    "glass": "glass",
    "trash": None,  # mixed/unclassifiable trash is dropped
}
# TACO is used only to supplement the organic class.
TACO_MAP = {
    "food": "organic",
    "food_waste": "organic",
    "vegetable": "organic",
    "fruit": "organic",
    "organic": "organic",
}
REALWASTE_MAP = {
    "cardboard": "paper",
    "food_organics": "organic",
    "glass": "glass",
    "metal": "metal",
    "paper": "paper",
    "plastic": "plastic",
    "vegetation": "organic",
}
# Operator-reviewed feedback folders are already named after the target classes.
FEEDBACK_MAP = {
    "plastic": "plastic",
    "paper": "paper",
    "organic": "organic",
    "metal": "metal",
    "glass": "glass",
}
# Extra local images follow the same identity mapping as operator feedback.
LOCAL_BOOST_MAP = FEEDBACK_MAP.copy()

# The five classes the processed dataset is organized into.
TARGET_CLASSES = ["plastic", "paper", "organic", "metal", "glass"]
# Split fractions; "test" receives the rounding remainder in split_and_copy.
SPLITS = {"train": 0.70, "val": 0.15, "test": 0.15}
# File extensions treated as images when scanning source folders.
INPUT_EXTS = {".jpg", ".jpeg", ".png", ".webp"}
def phash(path: str, size: int = 16) -> str:
    """Compute an average-hash fingerprint for near-duplicate detection.

    The image is converted to grayscale, shrunk to ``size`` x ``size``, and
    each pixel is encoded as '1' when brighter than the mean, '0' otherwise.

    Returns:
        A string of ``size * size`` bits, or "" when the file cannot be
        read/decoded as an image (callers skip empty hashes).
    """
    try:
        gray = Image.open(path).convert("L").resize((size, size), Image.LANCZOS)
        values = list(gray.getdata())
        mean = sum(values) / len(values)
        bits = ["1" if value > mean else "0" for value in values]
        return "".join(bits)
    except Exception:
        # Best-effort: any unreadable file simply yields no fingerprint.
        return ""
def verify_image(path: str) -> bool:
    """Return True if *path* opens and passes PIL's structural integrity check.

    NOTE: ``Image.verify()`` validates file structure/headers without fully
    decoding pixel data, so some subtle corruption can still slip through.
    """
    try:
        with Image.open(path) as img:
            img.verify()
    # The original caught `(UnidentifiedImageError, Exception)` — the second
    # element subsumes the first, so a single broad catch is equivalent and
    # clearer. Broad on purpose: any failure means "not a usable image".
    except Exception:
        return False
    return True
def collect_images(source_dir: str, class_map: dict) -> dict:
    """Gather image paths from *source_dir*, bucketed by target class.

    Each immediate subfolder's name (lowercased) is looked up in *class_map*;
    folders absent from the map or mapped to None are skipped. Matching
    folders are scanned recursively for files with a known image extension.

    Returns:
        {target_class: [path strings]} with a key for every TARGET_CLASSES entry.
    """
    buckets = {name: [] for name in TARGET_CLASSES}
    for entry in Path(source_dir).iterdir():
        if not entry.is_dir():
            continue
        mapped = class_map.get(entry.name.lower())
        if mapped is None:
            continue
        buckets[mapped].extend(
            str(candidate)
            for candidate in entry.rglob("*")
            if candidate.suffix.lower() in INPUT_EXTS
        )
    return buckets
def deduplicate(paths: list[str], threshold: int = 8) -> list[str]:
    """Greedy near-duplicate filter using perceptual-hash Hamming distance.

    A path is kept only when its hash differs from every previously kept
    hash by more than *threshold* bits. Paths whose hash cannot be computed
    (phash returns "") are dropped entirely. Order of *paths* is preserved.
    """
    kept_hashes: list[str] = []
    kept_paths: list[str] = []
    for candidate in paths:
        fingerprint = phash(candidate)
        if not fingerprint:
            continue
        # Keep the candidate only if it is far from all accepted fingerprints.
        far_from_all = all(
            sum(x != y for x, y in zip(fingerprint, known)) > threshold
            for known in kept_hashes
        )
        if far_from_all:
            kept_hashes.append(fingerprint)
            kept_paths.append(candidate)
    return kept_paths
def reset_output_dir(output_dir: str) -> None:
    """Remove any existing train/val/test subtrees under *output_dir*.

    Leaves *output_dir* itself (and anything else inside it) untouched;
    does nothing when the directory does not exist yet.
    """
    base = Path(output_dir)
    if not base.exists():
        return
    for target in (base / split_name for split_name in SPLITS):
        if target.exists():
            shutil.rmtree(target)
def split_and_copy(images: dict, output_dir: str) -> dict:
    """Deterministically shuffle, split, and copy images per class.

    Each class list is shuffled with a fixed seed, partitioned per the
    SPLITS ratios (the "test" split absorbs the rounding remainder), and
    copied into ``<output_dir>/<split>/<class>/`` with stable sequential
    file names.

    Args:
        images: {class_name: [source image paths]}. The caller's lists are
            NOT mutated (the original shuffled them in place).
        output_dir: root directory for the processed dataset.

    Returns:
        {class_name: {"total": n, "train": n, "val": n, "test": n}}.
    """
    # A private RNG gives the exact same permutations as the original
    # seed(42)+shuffle sequence without clobbering the global `random` state.
    rng = random.Random(42)
    stats: dict = {}
    for cls, src_paths in images.items():
        # Shuffle a copy so callers keep their original ordering.
        shuffled = list(src_paths)
        rng.shuffle(shuffled)
        total = len(shuffled)
        n_train = int(total * SPLITS["train"])
        n_val = int(total * SPLITS["val"])
        split_paths = {
            "train": shuffled[:n_train],
            "val": shuffled[n_train:n_train + n_val],
            "test": shuffled[n_train + n_val:],  # remainder -> test
        }
        for split, items in split_paths.items():
            dest_dir = Path(output_dir) / split / cls
            dest_dir.mkdir(parents=True, exist_ok=True)
            for index, src in enumerate(items):
                # Rename on copy: collision-free, sortable names per class/split.
                ext = Path(src).suffix.lower()
                shutil.copy2(src, dest_dir / f"{cls}_{split}_{index:05d}{ext}")
        stats[cls] = {"total": total, **{split: len(items) for split, items in split_paths.items()}}
    return stats
def main() -> None:
    """CLI entry point: collect, verify, deduplicate, split, and report.

    Pipeline:
        1. Gather image paths from every source directory that exists.
        2. Drop files PIL cannot verify.
        3. Remove perceptual-hash near-duplicates within each class.
        4. Reset the output dir, then split/copy into train/val/test.
        5. Print a per-class summary table.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--trashnet_dir",
        default="data/raw/trashnet",
        help="Path to the unzipped TrashNet dataset",
    )
    parser.add_argument(
        "--taco_dir",
        default="data/raw/taco",
        help="Path to the TACO image folders",
    )
    parser.add_argument(
        "--realwaste_dir",
        default="data/raw/realwaste",
        help="Path to the organized RealWaste image folders",
    )
    parser.add_argument(
        "--feedback_dir",
        default="data/feedback_labeled",
        help="Path to operator-reviewed feedback images organized by class",
    )
    parser.add_argument(
        "--extra_dir",
        default="data/local_boost",
        help="Path to extra local training images organized by class",
    )
    parser.add_argument("--output_dir", default="data/processed")
    args = parser.parse_args()

    # One entry per dataset: (log label, directory, folder->class map,
    # warning emitted when the directory is missing). This replaces five
    # copy-pasted if/else stanzas; log output is unchanged.
    sources = [
        ("TrashNet", args.trashnet_dir, TRASHNET_MAP,
         "TrashNet dir not found: %s\nDownload from https://github.com/garythung/trashnet and unzip."),
        ("TACO", args.taco_dir, TACO_MAP,
         "TACO dir not found: %s. Skipping organic supplement."),
        ("RealWaste", args.realwaste_dir, REALWASTE_MAP,
         "RealWaste dir not found: %s. Skipping RealWaste supplement."),
        ("operator feedback", args.feedback_dir, FEEDBACK_MAP,
         "Feedback dir not found: %s. Skipping operator feedback supplement."),
        ("local boost dataset", args.extra_dir, LOCAL_BOOST_MAP,
         "Local boost dir not found: %s. Skipping local boost supplement."),
    ]
    all_images = {cls: [] for cls in TARGET_CLASSES}
    for label, directory, class_map, missing_msg in sources:
        if not os.path.isdir(directory):
            log.warning(missing_msg, directory)
            continue
        log.info("Collecting from %s...", label)
        found = collect_images(directory, class_map)
        for cls in TARGET_CLASSES:
            all_images[cls].extend(found[cls])

    # Drop unreadable files so training never hits decode errors.
    log.info("Verifying images...")
    for cls in TARGET_CLASSES:
        before = len(all_images[cls])
        all_images[cls] = [path for path in all_images[cls] if verify_image(path)]
        removed = before - len(all_images[cls])
        if removed:
            log.warning(" %s: removed %s corrupted files", cls, removed)

    # Near-duplicate removal reduces train/test leakage of near-identical shots.
    log.info("Deduplicating...")
    for cls in TARGET_CLASSES:
        before = len(all_images[cls])
        all_images[cls] = deduplicate(all_images[cls])
        log.info(" %s: %s -> %s after dedup", cls, before, len(all_images[cls]))

    log.info("Splitting and copying...")
    reset_output_dir(args.output_dir)
    stats = split_and_copy(all_images, args.output_dir)

    # Human-readable summary table on stdout (logs go to stderr by default).
    print("\nDataset summary")
    print(f"{'Class':<12} {'Total':>7} {'Train':>7} {'Val':>7} {'Test':>7}")
    print("-" * 44)
    for cls, summary in stats.items():
        print(
            f"{cls:<12} {summary['total']:>7} {summary['train']:>7} "
            f"{summary['val']:>7} {summary['test']:>7}"
        )
    grand_total = sum(summary["total"] for summary in stats.values())
    print(f"\nTotal images: {grand_total}")
    print(f"Output dir : {os.path.abspath(args.output_dir)}")


if __name__ == "__main__":
    main()