File size: 6,160 Bytes
0b86da8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
"""

Download and prepare TrashNet dataset from GitHub.



Fetches the dataset-resized.zip, converts all images to numpy arrays

with consistent size, and splits into train/val/test sets.



Saved files:

    data/processed/X_train.npy, y_train.npy

    data/processed/X_val.npy,   y_val.npy

    data/processed/X_test.npy,  y_test.npy

    data/processed/classes.npy

"""

import io
import os
import sys
import zipfile
from pathlib import Path

import numpy as np
import requests
import torch
import yaml
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset

# Add the project root to the python path
sys.path.append(str(Path(__file__).parent.parent))


def load_config(config_path="config.yaml"):
    """Parse a YAML configuration file and return its contents.

    Args:
        config_path: Path of the YAML file to read. A relative path is
            resolved against the current working directory.

    Returns:
        The parsed document exactly as produced by ``yaml.safe_load``.

    Raises:
        OSError: If the file cannot be opened.
    """
    # Pin the encoding: without it open() falls back to the platform's
    # locale encoding, which mis-decodes non-ASCII YAML on some systems.
    with open(config_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)


# Settings are read once, at import time, from config.yaml.
config = load_config()
CLASSES = config["classes"]  # class folder names; list position = label id
SPLIT = config["split"]  # unpacked later as (train, val, test) ratios
IMG_SIZE = tuple(config["img_size"])  # target size handed to PIL's resize()
RANDOM_SEED = config["random_seed"]  # seed shared by both split stages

# Where the archive comes from and where the pipeline reads/writes on disk.
GITHUB_URL = "https://github.com/garythung/trashnet/raw/master/data/dataset-resized.zip"
RAW_DIR = Path("data/raw")
OUT_DIR = Path("data/processed")

# Make sure the top-level data folder exists as soon as the module is imported.
os.makedirs("data", exist_ok=True)


def download_and_extract() -> Path:
    """Download the TrashNet archive and unpack it under ``RAW_DIR``.

    The work is skipped when ``RAW_DIR / "dataset-resized"`` already exists.
    Otherwise the ZIP at ``GITHUB_URL`` is fetched fully into memory and
    extracted into ``RAW_DIR``.

    Returns:
        ``RAW_DIR``, i.e. the directory that contains ``dataset-resized``.

    Raises:
        requests.RequestException: If the HTTP request fails or returns an
            error status.
        zipfile.BadZipFile: If the downloaded payload is not a valid ZIP.
        OSError: If the archive cannot be written to disk.
    """
    if (RAW_DIR / "dataset-resized").exists():
        print("[SKIP] Already extracted.")
        return RAW_DIR

    print("[DOWNLOAD] TrashNet from GitHub...")

    # Network phase: only request errors are reported as a failed download.
    try:
        response = requests.get(GITHUB_URL, timeout=120)
        response.raise_for_status()
    except requests.RequestException as e:
        print(f"[ERROR] Download failed: {e}")
        raise

    # Extraction phase: kept separate so a corrupt archive or a disk error
    # is not mislabelled as a download problem.
    try:
        with zipfile.ZipFile(io.BytesIO(response.content)) as zf:
            zf.extractall(RAW_DIR)
    except (zipfile.BadZipFile, OSError) as e:
        print(f"[ERROR] Extraction failed: {e}")
        raise

    print("[OK] Extraction done.")
    return RAW_DIR


def load_images(raw_path: Path) -> tuple[np.ndarray, np.ndarray]:
    """Load every class image as a fixed-size RGB ``uint8`` array.

    For each name in ``CLASSES`` the folder
    ``raw_path / "dataset-resized" / <name>`` is scanned for ``*.jpg`` and
    ``*.png`` files. Each image is converted to RGB and resized to
    ``IMG_SIZE``; its label is the position of the class name in ``CLASSES``.
    Missing class folders and unreadable images are reported and skipped.

    Args:
        raw_path: Directory that contains the ``dataset-resized`` folder.

    Returns:
        ``(images, labels)``: the stacked image arrays and an ``int64``
        array of class indices, in matching order. When no image could be
        loaded both arrays are empty.
    """
    images, labels = [], []

    for label_index, class_name in enumerate(CLASSES):
        class_dir = raw_path / "dataset-resized" / class_name
        if not class_dir.exists():
            print(f"[WARN] Folder not found: {class_dir}, skipping.")
            continue

        # glob() yields entries in filesystem order, which differs between
        # machines. Sorting fixes the sample order so the seeded split that
        # follows is actually reproducible.
        files = sorted([*class_dir.glob("*.jpg"), *class_dir.glob("*.png")])
        print(f"[LOAD] {class_name}: {len(files)} images")

        for img_path in files:
            try:
                # The context manager releases the file handle right away
                # instead of leaving it to garbage collection.
                with Image.open(img_path) as src:
                    img = src.convert("RGB").resize(IMG_SIZE)
                images.append(np.array(img, dtype=np.uint8))
                labels.append(label_index)
            except Exception as e:
                # Deliberately best-effort: one bad file must not abort the run.
                print(f"[WARN] Could not read {img_path}: {e}")

    return np.array(images), np.array(labels, dtype=np.int64)


def split_and_save(images: np.ndarray, labels: np.ndarray):
    """Create stratified train/val/test splits and write them as .npy files.

    The ratios come from ``SPLIT`` (train, val, test). The data is split in
    two stages, both seeded with ``RANDOM_SEED`` and stratified by label:
    first train versus the remainder, then the remainder into val and test.

    Written to ``OUT_DIR`` (created if needed): ``X_train``, ``y_train``,
    ``X_val``, ``y_val``, ``X_test``, ``y_test`` and ``classes`` (the
    ``CLASSES`` list), each as a ``.npy`` file.

    Args:
        images: Image array, one entry per sample.
        labels: Class index for each entry of ``images``.
    """
    OUT_DIR.mkdir(parents=True, exist_ok=True)
    train_ratio, val_ratio, test_ratio = SPLIT

    # Stage 1: carve the training set off from everything else.
    X_train, X_rest, y_train, y_rest = train_test_split(
        images,
        labels,
        test_size=(1 - train_ratio),
        stratify=labels,
        random_state=RANDOM_SEED,
    )

    # Stage 2: divide the remainder; val_size is val's share of that remainder.
    val_size = val_ratio / (val_ratio + test_ratio)
    X_val, X_test, y_val, y_test = train_test_split(
        X_rest,
        y_rest,
        test_size=(1 - val_size),
        stratify=y_rest,
        random_state=RANDOM_SEED,
    )

    named_arrays = (
        ("X_train", X_train),
        ("y_train", y_train),
        ("X_val", X_val),
        ("y_val", y_val),
        ("X_test", X_test),
        ("y_test", y_test),
    )
    for stem, data in named_arrays:
        np.save(OUT_DIR / f"{stem}.npy", data)
        print(f"[OK] {stem}.npy → {data.shape}  dtype={data.dtype}")

    # Store the class names next to the arrays.
    np.save(OUT_DIR / "classes.npy", np.array(CLASSES))
    print(f"\n[DONE] Splits: Train={len(y_train)} | Val={len(y_val)} | Test={len(y_test)}")


class TrashDataset(Dataset):
    """PyTorch ``Dataset`` over one pre-processed split stored as .npy files.

    Both arrays are read fully into memory when the dataset is constructed;
    nothing is loaded lazily. Per-sample work (the optional ``transform`` or
    the default tensor conversion) happens in ``__getitem__``.

    Attributes are plain instance fields:
        X: image array as loaded from ``x_path``.
        y: label tensor built from the array at ``y_path``.
        transform: optional callable applied to each image, or ``None``.
    """

    def __init__(self, x_path: Path, y_path: Path, transform=None):
        """Load the image and label arrays from disk.

        Args:
            x_path: ``.npy`` file holding the images.
            y_path: ``.npy`` file holding the labels.
            transform: Optional callable that receives the raw NumPy image
                and returns the sample to hand out. When ``None`` the
                default conversion in ``__getitem__`` is used.
        """
        self.X = np.load(x_path)
        self.y = torch.from_numpy(np.load(y_path))
        self.transform = transform

    def __len__(self):
        """Return the number of samples (length of the image array)."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at ``idx``.

        With a transform, the image is whatever ``transform(raw_image)``
        returns. Without one, the image is permuted with ``(2, 0, 1)``,
        cast to float and divided by 255 -- assuming the stored arrays are
        [H, W, C] with 0-255 values, that gives [C, H, W] in 0.0-1.0.
        The label is the matching element of ``self.y``.
        """
        img = self.X[idx]
        label = self.y[idx]

        # Compare against None explicitly: a callable can be falsy (e.g. one
        # that defines __len__ and is empty) and must still be applied.
        if self.transform is not None:
            img = self.transform(img)
        else:
            # Default: Convert [H, W, C] (0-255) to [C, H, W] (0.0-1.0) for
            # PyTorch
            img = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0

        return img, label


if __name__ == "__main__":
    # Run the whole pipeline: fetch the archive, load the images as arrays,
    # then write the train/val/test splits to disk.
    split_and_save(*load_images(download_and_extract()))