Spaces:
Sleeping
Sleeping
| """ | |
| MNIST Data Loader Module | |
| Handles loading MNIST dataset from IDX binary format. | |
| """ | |
| import struct | |
| from array import array | |
| from pathlib import Path | |
| from typing import Tuple, List | |
| import numpy as np | |
| from numpy.typing import NDArray | |
class MnistDataloader:
    """
    Load MNIST handwritten digit dataset from IDX binary files.

    An IDX file starts with big-endian unsigned 32-bit header fields
    followed by the raw unsigned-byte payload. The first header field is a
    magic number identifying the file kind: 2051 for image files and 2049
    for label files.

    Attributes:
        training_images_filepath: Path to training images IDX file
        training_labels_filepath: Path to training labels IDX file
        test_images_filepath: Path to test images IDX file
        test_labels_filepath: Path to test labels IDX file
    """

    def __init__(
        self,
        training_images_filepath: str,
        training_labels_filepath: str,
        test_images_filepath: str,
        test_labels_filepath: str
    ) -> None:
        """
        Initialize MNIST data loader with file paths.

        The paths are only stored and checked for existence here; no file
        is opened or parsed until ``read_images_labels`` / ``load_data``.

        Args:
            training_images_filepath: Path to training images (.idx3-ubyte)
            training_labels_filepath: Path to training labels (.idx1-ubyte)
            test_images_filepath: Path to test images (.idx3-ubyte)
            test_labels_filepath: Path to test labels (.idx1-ubyte)

        Raises:
            FileNotFoundError: If any of the specified files don't exist
        """
        self.training_images_filepath = training_images_filepath
        self.training_labels_filepath = training_labels_filepath
        self.test_images_filepath = test_images_filepath
        self.test_labels_filepath = test_labels_filepath

        # Fail fast on a bad path instead of at the first read.
        for filepath in (
            training_images_filepath,
            training_labels_filepath,
            test_images_filepath,
            test_labels_filepath,
        ):
            if not Path(filepath).exists():
                raise FileNotFoundError(f"MNIST data file not found: {filepath}")

    def read_images_labels(
        self,
        images_filepath: str,
        labels_filepath: str
    ) -> Tuple[List[NDArray[np.uint8]], List[int]]:
        """
        Read images and labels from a pair of IDX binary files.

        Args:
            images_filepath: Path to images IDX file
            labels_filepath: Path to labels IDX file

        Returns:
            Tuple of (images, labels) where:
            - images: List of 2-D uint8 numpy arrays, one per image, each
              shaped (rows, cols) as declared in the image file header
              (28x28 for standard MNIST)
            - labels: List of integer labels, one per image

        Raises:
            ValueError: If a magic number doesn't match the expected value,
                if a file holds fewer payload bytes than its header
                declares, or if the two files declare different item counts
            struct.error: If a file is too short to contain its header
        """
        # Labels file: 8-byte header (magic, count), then one byte per label.
        with open(labels_filepath, 'rb') as file:
            magic, label_count = struct.unpack(">II", file.read(8))
            if magic != 2049:
                raise ValueError(
                    f'Magic number mismatch in labels file. '
                    f'Expected 2049, got {magic}'
                )
            label_bytes = file.read(label_count)
        if len(label_bytes) != label_count:
            raise ValueError(
                f'Labels file is truncated. Expected {label_count} labels, '
                f'got {len(label_bytes)}'
            )

        # Images file: 16-byte header (magic, count, rows, cols), then
        # count * rows * cols pixel bytes in row-major order.
        with open(images_filepath, 'rb') as file:
            magic, image_count, rows, cols = struct.unpack(
                ">IIII", file.read(16)
            )
            if magic != 2051:
                raise ValueError(
                    f'Magic number mismatch in images file. '
                    f'Expected 2051, got {magic}'
                )
            expected_bytes = image_count * rows * cols
            pixel_bytes = file.read(expected_bytes)
        if len(pixel_bytes) != expected_bytes:
            raise ValueError(
                f'Images file is truncated. Expected {expected_bytes} pixel '
                f'bytes, got {len(pixel_bytes)}'
            )

        # A mismatched pair would silently misalign images and labels.
        if image_count != label_count:
            raise ValueError(
                f'Image/label count mismatch. Images file has {image_count} '
                f'items, labels file has {label_count}'
            )

        # Decode every image in one pass. frombuffer over bytes is
        # read-only, so copy once to hand callers writable arrays.
        pixels = np.frombuffer(pixel_bytes, dtype=np.uint8)
        pixels = pixels.reshape(image_count, rows, cols).copy()

        # Iterating the 3-D array yields one (rows, cols) array per image;
        # iterating bytes yields plain Python ints.
        return list(pixels), list(label_bytes)

    def load_data(self) -> Tuple[
        Tuple[List[NDArray[np.uint8]], List[int]],
        Tuple[List[NDArray[np.uint8]], List[int]]
    ]:
        """
        Load complete MNIST dataset (training and test sets).

        Reads the training pair and then the test pair using the file paths
        given at construction time.

        Returns:
            Tuple of ((x_train, y_train), (x_test, y_test)) where, for the
            standard MNIST distribution:
            - x_train: 60,000 training images (28x28 uint8 arrays)
            - y_train: 60,000 training labels (0-9)
            - x_test: 10,000 test images (28x28 uint8 arrays)
            - y_test: 10,000 test labels (0-9)

        Raises:
            ValueError: Propagated from ``read_images_labels`` when a file
                is malformed or the image/label files are inconsistent

        Example:
            >>> loader = MnistDataloader(
            ...     'data/raw/train-images.idx3-ubyte',
            ...     'data/raw/train-labels.idx1-ubyte',
            ...     'data/raw/t10k-images.idx3-ubyte',
            ...     'data/raw/t10k-labels.idx1-ubyte'
            ... )
            >>> (x_train, y_train), (x_test, y_test) = loader.load_data()
            >>> print(f"Training: {len(x_train)} images")
            Training: 60000 images
            >>> print(f"Test: {len(x_test)} images")
            Test: 10000 images
        """
        x_train, y_train = self.read_images_labels(
            self.training_images_filepath,
            self.training_labels_filepath
        )
        x_test, y_test = self.read_images_labels(
            self.test_images_filepath,
            self.test_labels_filepath
        )
        return (x_train, y_train), (x_test, y_test)