"""Download and prepare the TrashNet dataset from GitHub.

Fetches ``dataset-resized.zip``, converts all images to NumPy arrays of a
consistent size, and splits them into train/val/test sets.

Saved files:
    data/processed/X_train.npy, y_train.npy
    data/processed/X_val.npy,   y_val.npy
    data/processed/X_test.npy,  y_test.npy
    data/processed/classes.npy

Note: importing this module reads ``config.yaml`` from the current working
directory and creates the ``data`` directory.
"""

import io
import os
import sys
import zipfile
from pathlib import Path

import numpy as np
import requests
import torch
import yaml
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset

# Add the project root to the python path
sys.path.append(str(Path(__file__).parent.parent))


def load_config(config_path="config.yaml"):
    """Parse a YAML configuration file and return its contents.

    Args:
        config_path: Path of the YAML file; relative paths are resolved
            against the current working directory.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the file is not valid YAML.
    """
    # Explicit encoding so parsing does not depend on the platform default.
    with open(config_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)


config = load_config()

CLASSES = config["classes"]
SPLIT = config["split"]
# Passed straight to PIL's ``Image.resize``, which takes (width, height).
IMG_SIZE = tuple(config["img_size"])
RANDOM_SEED = config["random_seed"]

os.makedirs(name="data", exist_ok=True)

GITHUB_URL = "https://github.com/garythung/trashnet/raw/master/data/dataset-resized.zip"
RAW_DIR = Path("data/raw")
OUT_DIR = Path("data/processed")


def download_and_extract() -> Path:
    """Download the TrashNet ZIP from GitHub and extract it to ``data/raw/``.

    The download is skipped when ``data/raw/dataset-resized`` already exists.

    Returns:
        ``RAW_DIR``, the directory that contains ``dataset-resized``.

    Raises:
        Exception: Any download or extraction error is reported on stdout
            and then re-raised unchanged.
    """
    if (RAW_DIR / "dataset-resized").exists():
        print("[SKIP] Already extracted.")
        return RAW_DIR

    print("[DOWNLOAD] TrashNet from GitHub...")
    try:
        response = requests.get(GITHUB_URL, timeout=120)
        response.raise_for_status()
        # The archive is held in memory and unpacked without a temp file.
        with zipfile.ZipFile(io.BytesIO(response.content)) as zf:
            zf.extractall(RAW_DIR)
        print("[OK] Extraction done.")
    except Exception as e:
        # Top-level boundary: report the failure, then let the caller see it.
        print(f"[ERROR] Download failed: {e}")
        raise
    return RAW_DIR


def load_images(raw_path: Path) -> tuple[np.ndarray, np.ndarray]:
    """Read every class folder and return images and labels as arrays.

    Each ``*.jpg`` / ``*.png`` file is converted to RGB and resized to
    ``IMG_SIZE``, so all samples share one shape regardless of the original
    image dimensions or format. Missing class folders and unreadable files
    are reported on stdout and skipped.

    Args:
        raw_path: Directory that contains the ``dataset-resized`` folder.

    Returns:
        ``(images, labels)``: ``images`` is a ``uint8`` array stacking the
        loaded images, ``labels`` an ``int64`` array holding each image's
        index into ``CLASSES``.
    """
    images, labels = [], []

    for label_index, class_name in enumerate(CLASSES):
        class_dir = raw_path / "dataset-resized" / class_name
        if not class_dir.exists():
            print(f"[WARN] Folder not found: {class_dir}, skipping.")
            continue

        # glob() order depends on the filesystem; sorting keeps the sample
        # order, and therefore the seeded split, identical across machines.
        files = sorted([*class_dir.glob("*.jpg"), *class_dir.glob("*.png")])
        print(f"[LOAD] {class_name}: {len(files)} images")

        for img_path in files:
            try:
                # The context manager closes the file handle promptly.
                with Image.open(img_path) as raw_img:
                    rgb_img = raw_img.convert("RGB")
                resized = rgb_img.resize(IMG_SIZE)
                images.append(np.array(resized, dtype=np.uint8))
                labels.append(label_index)
            except Exception as e:
                # Deliberate best effort: one corrupt file must not abort
                # the whole run.
                print(f"[WARN] Could not read {img_path}: {e}")

    # Pin the dtype so an empty result is uint8 rather than float64.
    return np.array(images, dtype=np.uint8), np.array(labels, dtype=np.int64)


def split_and_save(images: np.ndarray, labels: np.ndarray) -> None:
    """Split the data into train/val/test sets and save them as ``.npy``.

    - Train: used to update model weights.
    - Val: used to tune hyperparameters and detect overfitting.
    - Test: used for the final evaluation.

    Both splits are stratified by label and seeded with ``RANDOM_SEED``.
    The arrays and ``classes.npy`` are written to ``OUT_DIR``.

    Args:
        images: Image array; its first axis indexes samples.
        labels: Label array aligned with ``images``.

    Raises:
        ValueError: If no samples were provided, or if scikit-learn rejects
            the split (for example a class with too few samples).
    """
    if len(labels) == 0:
        raise ValueError(
            "No images were loaded; cannot create train/val/test splits."
        )

    OUT_DIR.mkdir(parents=True, exist_ok=True)

    train_ratio, val_ratio, test_ratio = SPLIT

    X_train, X_rest, y_train, y_rest = train_test_split(
        images,
        labels,
        test_size=(1 - train_ratio),
        stratify=labels,
        random_state=RANDOM_SEED,
    )

    # Share of the remainder that belongs to validation; the second split
    # hands the complementary share to the test set.
    val_size = val_ratio / (val_ratio + test_ratio)
    X_val, X_test, y_val, y_test = train_test_split(
        X_rest,
        y_rest,
        test_size=(1 - val_size),
        stratify=y_rest,
        random_state=RANDOM_SEED,
    )

    splits = {
        "X_train": X_train,
        "y_train": y_train,
        "X_val": X_val,
        "y_val": y_val,
        "X_test": X_test,
        "y_test": y_test,
    }

    for name, array in splits.items():
        path = OUT_DIR / f"{name}.npy"
        np.save(path, array)
        print(f"[OK] {name}.npy → {array.shape} dtype={array.dtype}")

    np.save(OUT_DIR / "classes.npy", np.array(CLASSES))
    print(f"\n[DONE] Splits: "
          f"Train={len(y_train)} | Val={len(y_val)} | Test={len(y_test)}")


class TrashDataset(Dataset):
    """PyTorch ``Dataset`` over one pre-processed ``.npy`` split.

    Both arrays are loaded completely into memory when the dataset is
    constructed (this is not lazy loading). Per-sample work happens in
    ``__getitem__``:

    1. Data augmentation: an optional ``transform`` is applied to each
       image on the fly.
    2. Tensor conversion: without a transform, the image array is turned
       into a float tensor scaled to 0.0-1.0.
    """

    def __init__(self, x_path: Path, y_path: Path, transform=None):
        """Load the pre-processed ``.npy`` files once into memory.

        Args:
            x_path: Path of the image array file.
            y_path: Path of the label array file.
            transform: Optional callable applied to each image array in
                ``__getitem__``; it replaces the default conversion.
        """
        self.X = np.load(x_path)
        self.y = torch.from_numpy(np.load(y_path))
        self.transform = transform

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair at ``idx``.

        With a transform, the image is whatever the transform returns for
        the raw array. Otherwise it is a float tensor with the last axis
        moved to the front and values divided by 255.
        """
        img = self.X[idx]
        label = self.y[idx]

        # Identity check: a falsy but valid transform (e.g. an empty
        # ``nn.Sequential``) must still be applied.
        if self.transform is not None:
            img = self.transform(img)
        else:
            # Default: Convert [H, W, C] (0-255) to [C, H, W] (0.0-1.0) for
            # PyTorch
            img = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0

        return img, label


def main() -> None:
    """Run the full pipeline: download, load, split and save."""
    raw_path = download_and_extract()
    images, labels = load_images(raw_path)
    split_and_save(images, labels)


if __name__ == "__main__":
    main()