Spaces:
Sleeping
Sleeping
| """ | |
| Download and prepare TrashNet dataset from GitHub. | |
| Fetches the dataset-resized.zip, converts all images to numpy arrays | |
| with consistent size, and splits into train/val/test sets. | |
| Saved files: | |
| data/processed/X_train.npy, y_train.npy | |
| data/processed/X_val.npy, y_val.npy | |
| data/processed/X_test.npy, y_test.npy | |
| data/processed/classes.npy | |
| """ | |
| import io | |
| import os | |
| import sys | |
| import zipfile | |
| from pathlib import Path | |
| import numpy as np | |
| import requests | |
| import torch | |
| import yaml | |
| from PIL import Image | |
| from sklearn.model_selection import train_test_split | |
| from torch.utils.data import Dataset | |
# Make modules that live in the repository root (this file's grandparent
# directory) importable when the script is run directly.
sys.path.append(os.fspath(Path(__file__).parent.parent))
def load_config(config_path="config.yaml"):
    """
    Read the project's YAML configuration file.

    Args:
        config_path: Path to the YAML file. A relative path is resolved against
            the current working directory, not against this file's location.

    Returns:
        The parsed document, exactly as produced by ``yaml.safe_load``.
    """
    # YAML files are normally UTF-8; decode explicitly instead of relying on the
    # platform's locale encoding (e.g. cp1252 on Windows).
    with open(config_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)
# Settings are read once, at import time, from ./config.yaml.
config = load_config()
CLASSES, SPLIT, IMG_SIZE, RANDOM_SEED = (
    config["classes"],          # class names; a class's list index is its integer label
    config["split"],            # unpacked later as (train, val, test) ratios
    tuple(config["img_size"]),  # handed to PIL's Image.resize, i.e. (width, height)
    config["random_seed"],
)

os.makedirs("data", exist_ok=True)

GITHUB_URL = "https://github.com/garythung/trashnet/raw/master/data/dataset-resized.zip"
RAW_DIR, OUT_DIR = Path("data", "raw"), Path("data", "processed")
def download_and_extract() -> Path:
    """
    Automates data acquisition for reproducibility.
    Downloads the TrashNet ZIP from GitHub and extracts it to data/raw/.
    This ensures that anyone running the script gets the exact same starting data.

    The step is skipped when ``data/raw/dataset-resized`` already exists. If the
    download or extraction does not complete (error or interrupt), any partially
    written ``dataset-resized`` folder is removed. Otherwise the skip check would
    accept incomplete data on the next run.

    Returns:
        RAW_DIR, the directory the archive was extracted into.

    Raises:
        Any exception raised by ``requests`` or ``zipfile``; it is logged and re-raised.
    """
    import shutil

    target = RAW_DIR / "dataset-resized"
    if target.exists():
        print("[SKIP] Already extracted.")
        return RAW_DIR
    print("[DOWNLOAD] TrashNet from GitHub...")
    completed = False
    try:
        response = requests.get(GITHUB_URL, timeout=120)
        response.raise_for_status()
        with zipfile.ZipFile(io.BytesIO(response.content)) as zf:
            zf.extractall(RAW_DIR)
        completed = True
        print("[OK] Extraction done.")
    except Exception as e:
        print(f"[ERROR] Download failed: {e}")
        raise
    finally:
        # `finally` (not just `except Exception`) so Ctrl-C mid-extraction is covered too.
        if not completed:
            shutil.rmtree(target, ignore_errors=True)
    return RAW_DIR
def load_images(raw_path: Path) -> tuple[np.ndarray, np.ndarray]:
    """
    Data Standardization:
    Reads all images per class, resizes them to a uniform size (IMG_SIZE),
    and converts them to RGB. This creates a consistent input format for
    the neural network, regardless of the original image dimensions or formats.

    Files are visited in sorted order, so the sample order (and therefore the
    seeded train/val/test split made from it) is the same on every machine.
    Unreadable images are reported and skipped; missing class folders are
    reported and skipped.

    Args:
        raw_path: Directory that contains the extracted ``dataset-resized`` folder.

    Returns:
        ``(images, labels)``. ``images`` is a uint8 array of shape
        ``(N, IMG_SIZE[1], IMG_SIZE[0], 3)``, since PIL sizes are (width, height).
        ``labels`` is an int64 array of shape ``(N,)`` holding each image's
        index into ``CLASSES``. If nothing is loaded, both arrays are empty
        but keep these dtypes and trailing dimensions.
    """
    extensions = {".jpg", ".jpeg", ".png"}
    images, labels = [], []
    for label_index, class_name in enumerate(CLASSES):
        class_dir = raw_path / "dataset-resized" / class_name
        if not class_dir.is_dir():
            print(f"[WARN] Folder not found: {class_dir}, skipping.")
            continue
        # Directory iteration order is filesystem-dependent; sort for a reproducible split.
        files = sorted(
            p for p in class_dir.iterdir() if p.is_file() and p.suffix.lower() in extensions
        )
        print(f"[LOAD] {class_name}: {len(files)} images")
        for img_path in files:
            try:
                # Context manager guarantees the underlying file handle is closed.
                with Image.open(img_path) as img:
                    rgb = img.convert("RGB").resize(IMG_SIZE)
                images.append(np.asarray(rgb, dtype=np.uint8))
                labels.append(label_index)
            except Exception as e:
                # Deliberate best-effort: one corrupt file should not abort the whole load.
                print(f"[WARN] Could not read {img_path}: {e}")
    if not images:
        width, height = IMG_SIZE
        return np.empty((0, height, width, 3), dtype=np.uint8), np.empty(0, dtype=np.int64)
    return np.stack(images), np.array(labels, dtype=np.int64)
def split_and_save(images: np.ndarray, labels: np.ndarray):
    """
    Evaluation Rigor:
    Splits the data into fixed Train, Validation, and Test sets.
    - Train: Used to update model weights.
    - Val: Used to tune hyperparameters and prevent overfitting.
    - Test: Used for final unbiased evaluation.
    Both splits are stratified by label and seeded with RANDOM_SEED.
    Saving as .npy files makes loading much faster during training.

    Args:
        images: Image array whose first axis indexes samples.
        labels: Label array aligned with ``images``.

    Raises:
        ValueError: If SPLIT is not three positive (train, val, test) ratios
            that sum to 1.
    """
    # Coerce to float so sklearn always treats test_size as a fraction; an int
    # test_size would be read as an absolute sample count.
    ratios = tuple(float(r) for r in SPLIT)
    if len(ratios) != 3 or min(ratios) <= 0 or abs(sum(ratios) - 1.0) > 1e-6:
        raise ValueError(
            f"config['split'] must be three positive (train, val, test) ratios "
            f"summing to 1, got {SPLIT!r}"
        )
    _, val_ratio, test_ratio = ratios
    holdout_ratio = val_ratio + test_ratio

    OUT_DIR.mkdir(parents=True, exist_ok=True)
    # Pass the holdout share directly rather than `1 - train_ratio`: the
    # subtraction carries float error (1 - 0.7 == 0.30000000000000004), and
    # sklearn rounds test_size * n_samples up, so that error can move a sample.
    X_train, X_rest, y_train, y_rest = train_test_split(
        images, labels, test_size=holdout_ratio, stratify=labels, random_state=RANDOM_SEED
    )
    # Fraction of the holdout that becomes the test set.
    X_val, X_test, y_val, y_test = train_test_split(
        X_rest,
        y_rest,
        test_size=test_ratio / holdout_ratio,
        stratify=y_rest,
        random_state=RANDOM_SEED,
    )
    splits = {
        "X_train": X_train,
        "y_train": y_train,
        "X_val": X_val,
        "y_val": y_val,
        "X_test": X_test,
        "y_test": y_test,
    }
    for name, array in splits.items():
        path = OUT_DIR / f"{name}.npy"
        np.save(path, array)
        print(f"[OK] {name}.npy → {array.shape} dtype={array.dtype}")
    np.save(OUT_DIR / "classes.npy", np.array(CLASSES))
    print(f"\n[DONE] Splits: " f"Train={len(y_train)} | Val={len(y_val)} | Test={len(y_test)}")
class TrashDataset(Dataset):
    """
    The Bridge to PyTorch:
    This class is REQUIRED because PyTorch's DataLoader expects a Dataset object.
    It wraps a pair of pre-processed .npy files (images and labels).
    Why this class?
    1. Loading: by default both arrays are read fully into RAM once, in __init__.
       Pass ``mmap_mode="r"`` to memory-map the image array instead. Only the
       rows actually indexed are then read from disk (lazy loading).
    2. Data Augmentation: Allows on-the-fly transformations (rotation, flip, etc.) in __getitem__.
    3. Tensor Conversion: Handles the conversion from NumPy arrays to PyTorch Tensors.
    """

    def __init__(self, x_path: Path, y_path: Path, transform=None, mmap_mode=None):
        """
        Load the pre-processed .npy files.

        Args:
            x_path: .npy file with the image array; its first axis indexes samples.
            y_path: .npy file with the label array, one label per image.
            transform: Optional callable applied to each image array in __getitem__.
            mmap_mode: Forwarded to ``np.load`` for the image array only
                (e.g. ``"r"``). ``None`` (the default) loads it fully into memory.

        Raises:
            ValueError: If the image and label arrays have different lengths.
        """
        self.X = np.load(x_path, mmap_mode=mmap_mode)
        self.y = torch.from_numpy(np.load(y_path))
        if len(self.X) != len(self.y):
            raise ValueError(
                f"Image/label count mismatch: {len(self.X)} images vs {len(self.y)} labels"
            )
        self.transform = transform

    def __len__(self):
        """Tells the DataLoader how many samples are in the dataset."""
        return len(self.X)

    def __getitem__(self, idx):
        """
        Fetches a single sample (image + label) at the given index.
        This is where preprocessing (transforms) happens during training.

        With a transform, it receives the raw image row. That row is read-only
        when ``mmap_mode="r"`` was used. Without one, the image is returned as a
        float32 [C, H, W] tensor scaled to [0, 1], which assumes the row is
        laid out [H, W, C].
        """
        img = self.X[idx]
        label = self.y[idx]
        # `is not None`: a valid transform object may still be falsy (e.g. defines __len__ == 0).
        if self.transform is not None:
            return self.transform(img), label
        # np.array copies into a writable plain ndarray, so torch.from_numpy does
        # not complain about read-only memory-mapped rows.
        tensor = torch.from_numpy(np.array(img)).permute(2, 0, 1).float() / 255.0
        return tensor, label
if __name__ == "__main__":
    # Full pipeline: fetch the archive -> decode/resize every image -> split and persist.
    split_and_save(*load_images(download_and_extract()))