File size: 3,125 Bytes
3e16037
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
"""
Data Loading and Preprocessing Module for CIFAR-10 Image Classification.

This module handles:
- Loading the CIFAR-10 dataset from Keras
- Normalizing pixel values to [0, 1]
- One-hot encoding labels
- Data augmentation via ImageDataGenerator
"""

import numpy as np
from tensorflow import keras
from keras.api.datasets import cifar10
from keras.api.utils import to_categorical
from keras.src.legacy.preprocessing.image import ImageDataGenerator


# Human-readable CIFAR-10 labels; the list position is the integer label (0-9).
CLASS_NAMES = [
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
]

# Number of target classes, derived from the label list so the two stay in sync.
NUM_CLASSES = len(CLASS_NAMES)

# Shape of a single image as (height, width, channels).
IMG_SHAPE = (32, 32, 3)


def load_cifar10_data():
    """
    Fetch CIFAR-10 through Keras and hand back the untouched splits.

    No normalization or label encoding is applied here; the arrays are
    returned exactly as ``cifar10.load_data()`` provides them.

    Returns:
        tuple: (x_train, y_train), (x_test, y_test)
            - x: uint8 images of shape (N, 32, 32, 3)
            - y: integer labels of shape (N, 1)
    """
    train_split, test_split = cifar10.load_data()

    # Unpack each split explicitly so a malformed result fails here.
    train_images, train_labels = train_split
    test_images, test_labels = test_split

    return (train_images, train_labels), (test_images, test_labels)


def preprocess_data(x_train, y_train, x_test, y_test):
    """
    Scale images into [0, 1] and convert integer labels to one-hot vectors.

    Args:
        x_train: Training images (uint8).
        y_train: Training labels (int).
        x_test: Test images (uint8).
        y_test: Test labels (int).

    Returns:
        tuple: (x_train, y_train, x_test, y_test)
            - x: float32 images normalized to [0, 1]
            - y: one-hot encoded labels of shape (N, NUM_CLASSES)
    """
    # Cast to float32 before dividing so the scaled images stay float32.
    train_images, test_images = [
        images.astype('float32') / 255.0
        for images in (x_train, x_test)
    ]

    # Integer class indices -> one-hot rows of width NUM_CLASSES.
    train_labels, test_labels = [
        to_categorical(labels, NUM_CLASSES)
        for labels in (y_train, y_test)
    ]

    return train_images, train_labels, test_images, test_labels


def create_data_augmentation_generator():
    """
    Build the ImageDataGenerator used to augment training images.

    Augmentations applied:
    - Random rotation up to 15 degrees
    - Random width/height shift up to 10%
    - Random horizontal flip
    - Random zoom up to 10%

    Returns:
        ImageDataGenerator: configured generator for training data augmentation.
    """
    augmentation_options = {
        'rotation_range': 15,
        'width_shift_range': 0.1,
        'height_shift_range': 0.1,
        'horizontal_flip': True,
        'zoom_range': 0.1,
        # Strategy for pixels exposed by a shift, rotation or zoom.
        'fill_mode': 'nearest',
    }
    return ImageDataGenerator(**augmentation_options)


def get_prepared_data():
    """
    Full pipeline: load CIFAR-10, preprocess, and return ready-to-use data.

    The augmentation generator is fitted on the training images only when
    one of its options needs dataset statistics (featurewise centering,
    featurewise std normalization, or ZCA whitening). For purely geometric
    augmentation ``fit`` computes no statistics yet still copies the whole
    training array, so the call is skipped in that case.

    Returns:
        dict with keys:
            'x_train', 'y_train', 'x_test', 'y_test': processed arrays
            'datagen': ImageDataGenerator for augmented training
    """
    (x_train, y_train), (x_test, y_test) = load_cifar10_data()
    x_train, y_train, x_test, y_test = preprocess_data(
        x_train, y_train, x_test, y_test
    )
    datagen = create_data_augmentation_generator()

    # Only these options make the generator depend on statistics of the
    # training set; without them fit() is a costly no-op.
    needs_statistics = (
        datagen.featurewise_center
        or datagen.featurewise_std_normalization
        or datagen.zca_whitening
    )
    if needs_statistics:
        datagen.fit(x_train)

    return {
        'x_train': x_train,
        'y_train': y_train,
        'x_test': x_test,
        'y_test': y_test,
        'datagen': datagen,
    }