trash-classifier / src /dataset.py
alshami-dev's picture
First Update to the App
0b86da8 verified
Raw
History Blame Contribute Delete
6.16 kB
"""
Download and prepare the TrashNet dataset from GitHub.

Fetches dataset-resized.zip, converts every image to a NumPy array of a
consistent size, and splits the result into train/val/test sets.

Saved files:
    data/processed/X_train.npy, y_train.npy
    data/processed/X_val.npy, y_val.npy
    data/processed/X_test.npy, y_test.npy
    data/processed/classes.npy
"""
import io
import os
import shutil
import sys
import zipfile
from pathlib import Path

import numpy as np
import requests
import torch
import yaml
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
# Append the directory one level above this file's folder (the project root)
# to sys.path so that modules located there can be imported from this file.
sys.path.append(str(Path(__file__).parent.parent))
def load_config(config_path="config.yaml"):
    """
    Read the project settings from a YAML file.

    Args:
        config_path: Location of the YAML file. A relative path is resolved
            against the current working directory.

    Returns:
        Whatever ``yaml.safe_load`` produces for the document (a dict for a
        mapping-style config file).

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the content is not valid YAML.
    """
    # Pin the encoding: without it open() falls back to the platform locale,
    # so a config containing non-ASCII text could parse differently (or fail)
    # on machines whose default encoding is not UTF-8.
    with open(config_path, encoding="utf-8") as f:
        return yaml.safe_load(f)
# ---------------------------------------------------------------------------
# Settings read from the YAML config when this module is imported.
# ---------------------------------------------------------------------------
config = load_config()
CLASSES = config["classes"]
SPLIT = config["split"]
IMG_SIZE = tuple(config["img_size"])  # handed to PIL's Image.resize()
RANDOM_SEED = config["random_seed"]

# Make sure the top-level data folder is present before anything writes to it.
os.makedirs("data", exist_ok=True)

# ---------------------------------------------------------------------------
# Where the data comes from and where it is stored on disk.
# ---------------------------------------------------------------------------
RAW_DIR = Path("data/raw")
OUT_DIR = Path("data/processed")
GITHUB_URL = "https://github.com/garythung/trashnet/raw/master/data/dataset-resized.zip"
def download_and_extract() -> Path:
    """
    Automates data acquisition for reproducibility.
    Downloads the TrashNet ZIP from GITHUB_URL and extracts it into RAW_DIR.
    If RAW_DIR/dataset-resized already exists, nothing is downloaded.

    Returns:
        RAW_DIR, the directory that contains the extracted "dataset-resized"
        folder.

    Raises:
        Exception: Any error raised while downloading or extracting is
            reported on stdout and then re-raised unchanged.
    """
    extracted_dir = RAW_DIR / "dataset-resized"
    if extracted_dir.exists():
        print("[SKIP] Already extracted.")
        return RAW_DIR

    print("[DOWNLOAD] TrashNet from GitHub...")
    extraction_complete = False
    try:
        response = requests.get(GITHUB_URL, timeout=120)
        response.raise_for_status()
        # The archive is held in memory and unpacked straight from the bytes.
        with zipfile.ZipFile(io.BytesIO(response.content)) as zf:
            zf.extractall(RAW_DIR)
        extraction_complete = True
        print("[OK] Extraction done.")
    except Exception as e:
        print(f"[ERROR] Download failed: {e}")
        raise
    finally:
        if not extraction_complete:
            # The folder did not exist when we started, so anything there now
            # is a partial extraction. Remove it; otherwise the next run would
            # hit the [SKIP] branch and train on an incomplete dataset.
            shutil.rmtree(extracted_dir, ignore_errors=True)
    return RAW_DIR
def load_images(raw_path: Path) -> tuple[np.ndarray, np.ndarray]:
    """
    Data Standardization:
    Reads all .jpg/.png images per class, converts them to RGB and resizes
    them to a uniform size (IMG_SIZE). This creates a consistent input format
    for the neural network, regardless of the original image dimensions or
    formats.

    Args:
        raw_path: Directory that contains the "dataset-resized" folder, which
            in turn holds one sub-folder per entry of CLASSES.

    Returns:
        A pair (images, labels): the stacked uint8 image arrays and an int64
        array with, for each image, the index of its class within CLASSES.
        Missing class folders and unreadable files are reported on stdout and
        skipped.
    """
    images, labels = [], []
    for label_index, class_name in enumerate(CLASSES):
        # label_index comes from the position in CLASSES, so labels stay
        # aligned with CLASSES even when a class folder is skipped.
        class_dir = raw_path / "dataset-resized" / class_name
        if not class_dir.exists():
            print(f"[WARN] Folder not found: {class_dir}, skipping.")
            continue

        # glob() returns entries in an arbitrary, filesystem-dependent order.
        # Sorting fixes the sample order so that the seeded train/val/test
        # split comes out the same on every machine.
        files = sorted([*class_dir.glob("*.jpg"), *class_dir.glob("*.png")])
        print(f"[LOAD] {class_name}: {len(files)} images")

        for img_path in files:
            try:
                # The context manager releases the file handle as soon as the
                # pixels have been copied into the array.
                with Image.open(img_path) as img:
                    pixels = np.array(
                        img.convert("RGB").resize(IMG_SIZE), dtype=np.uint8
                    )
            except Exception as e:
                # Best effort: one broken file must not abort the whole load.
                print(f"[WARN] Could not read {img_path}: {e}")
                continue
            images.append(pixels)
            labels.append(label_index)
    return np.array(images), np.array(labels, dtype=np.int64)
def split_and_save(images: np.ndarray, labels: np.ndarray):
    """
    Evaluation Rigor:
    Partitions the samples into fixed, stratified Train / Validation / Test
    sets according to SPLIT and writes each part to OUT_DIR.
    - Train: Used to update model weights.
    - Val: Used to tune hyperparameters and prevent overfitting.
    - Test: Used for final unbiased evaluation.
    Every array is stored as a .npy file so training can load it directly;
    the class names are stored alongside as classes.npy.

    Args:
        images: All image arrays, one entry per sample.
        labels: Class index of each sample, in the same order as images.
    """
    OUT_DIR.mkdir(parents=True, exist_ok=True)
    frac_train, frac_val, frac_test = SPLIT

    # Stage 1: separate the training set; everything else is held out.
    X_train, X_holdout, y_train, y_holdout = train_test_split(
        images,
        labels,
        test_size=(1 - frac_train),
        stratify=labels,
        random_state=RANDOM_SEED,
    )

    # Stage 2: divide the held-out samples between validation and test.
    # val_share is the validation fraction *within* the held-out part.
    val_share = frac_val / (frac_val + frac_test)
    X_val, X_test, y_val, y_test = train_test_split(
        X_holdout,
        y_holdout,
        test_size=(1 - val_share),
        stratify=y_holdout,
        random_state=RANDOM_SEED,
    )

    named_arrays = (
        ("X_train", X_train),
        ("y_train", y_train),
        ("X_val", X_val),
        ("y_val", y_val),
        ("X_test", X_test),
        ("y_test", y_test),
    )
    for stem, data in named_arrays:
        np.save(OUT_DIR / f"{stem}.npy", data)
        print(f"[OK] {stem}.npy → {data.shape} dtype={data.dtype}")

    np.save(OUT_DIR / "classes.npy", np.array(CLASSES))
    print(f"\n[DONE] Splits: Train={len(y_train)} | Val={len(y_val)} | Test={len(y_test)}")
class TrashDataset(Dataset):
    """
    The Bridge to PyTorch:
    Wraps a pair of pre-processed .npy files (images + labels) in the Dataset
    interface that PyTorch's DataLoader expects.

    What this class does:
    1. Loading: both arrays are read fully into memory once, in __init__;
       __getitem__ only indexes into them.
    2. Data Augmentation: an optional transform callable is applied to each
       image on the fly in __getitem__.
    3. Tensor Conversion: when no transform is given, the image is converted
       from a NumPy [H, W, C] array (0-255) to a float [C, H, W] tensor
       (0.0-1.0).
    """

    def __init__(self, x_path: Path, y_path: Path, transform=None):
        """
        Loads the pre-processed .npy files once into memory.

        Args:
            x_path: Path to the .npy file holding the images.
            y_path: Path to the .npy file holding the labels.
            transform: Optional callable applied to each raw NumPy image in
                __getitem__. When None, the default tensor conversion is used.

        Raises:
            ValueError: If the two files hold a different number of samples.
        """
        self.X = np.load(x_path)
        self.y = torch.from_numpy(np.load(y_path))
        self.transform = transform
        if len(self.X) != len(self.y):
            # Fail early with a clear message instead of an IndexError (or
            # silently ignored samples) somewhere in the middle of training.
            raise ValueError(
                f"Sample count mismatch: {len(self.X)} images in {x_path} "
                f"but {len(self.y)} labels in {y_path}"
            )

    def __len__(self):
        """Tells the DataLoader how many samples are in the dataset."""
        return len(self.X)

    def __getitem__(self, idx):
        """
        Fetches a single sample (image + label) at the given index.
        This is where preprocessing (transforms) happens during training.

        Returns:
            A pair (img, label). img is the transform's output when a
            transform was supplied, otherwise a float tensor in [C, H, W]
            layout scaled to 0.0-1.0. label is the matching entry of the
            label tensor.
        """
        img = self.X[idx]
        label = self.y[idx]
        # Compare against None rather than relying on truthiness: a valid
        # transform can be falsy (e.g. an empty nn.Sequential has len() == 0)
        # and must still be applied.
        if self.transform is not None:
            return self.transform(img), label
        # Default: Convert [H, W, C] (0-255) to [C, H, W] (0.0-1.0) for
        # PyTorch
        return torch.from_numpy(img).permute(2, 0, 1).float() / 255.0, label
if __name__ == "__main__":
    # Script entry point: fetch the archive, decode the images, write the splits.
    dataset_root = download_and_extract()
    pixel_data, class_ids = load_images(dataset_root)
    split_and_save(pixel_data, class_ids)