faizan
fix: resolve all 468 ruff linting errors (code quality enforcement complete)
e77a25a
"""
MNIST Data Loader Module
Handles loading MNIST dataset from IDX binary format.
"""
import struct
from array import array
from pathlib import Path
from typing import Tuple, List
import numpy as np
from numpy.typing import NDArray
class MnistDataloader:
"""
Load MNIST handwritten digit dataset from IDX binary files.
The MNIST dataset uses a custom IDX binary format with magic numbers
to identify image (2051) and label (2049) files.
Attributes:
training_images_filepath: Path to training images IDX file
training_labels_filepath: Path to training labels IDX file
test_images_filepath: Path to test images IDX file
test_labels_filepath: Path to test labels IDX file
"""
def __init__(
self,
training_images_filepath: str,
training_labels_filepath: str,
test_images_filepath: str,
test_labels_filepath: str
) -> None:
"""
Initialize MNIST data loader with file paths.
Args:
training_images_filepath: Path to training images (.idx3-ubyte)
training_labels_filepath: Path to training labels (.idx1-ubyte)
test_images_filepath: Path to test images (.idx3-ubyte)
test_labels_filepath: Path to test labels (.idx1-ubyte)
Raises:
FileNotFoundError: If any of the specified files don't exist
"""
self.training_images_filepath = training_images_filepath
self.training_labels_filepath = training_labels_filepath
self.test_images_filepath = test_images_filepath
self.test_labels_filepath = test_labels_filepath
# Verify files exist
for filepath in [
training_images_filepath,
training_labels_filepath,
test_images_filepath,
test_labels_filepath
]:
if not Path(filepath).exists():
raise FileNotFoundError(f"MNIST data file not found: {filepath}")
def read_images_labels(
self,
images_filepath: str,
labels_filepath: str
) -> Tuple[List[NDArray[np.uint8]], List[int]]:
"""
Read images and labels from IDX binary files.
Args:
images_filepath: Path to images IDX file
labels_filepath: Path to labels IDX file
Returns:
Tuple of (images, labels) where:
- images: List of 28x28 numpy arrays (uint8)
- labels: List of integer labels (0-9)
Raises:
ValueError: If magic numbers don't match expected values
"""
# Read labels
labels = []
with open(labels_filepath, 'rb') as file:
magic, size = struct.unpack(">II", file.read(8))
if magic != 2049:
raise ValueError(
f'Magic number mismatch in labels file. '
f'Expected 2049, got {magic}'
)
labels = array("B", file.read())
# Read images
with open(images_filepath, 'rb') as file:
magic, size, rows, cols = struct.unpack(">IIII", file.read(16))
if magic != 2051:
raise ValueError(
f'Magic number mismatch in images file. '
f'Expected 2051, got {magic}'
)
image_data = array("B", file.read())
# Convert to list of 28x28 arrays
images = []
for i in range(size):
images.append([0] * rows * cols)
for i in range(size):
img = np.array(
image_data[i * rows * cols:(i + 1) * rows * cols],
dtype=np.uint8
)
img = img.reshape(rows, cols)
images[i][:] = img
return images, list(labels)
def load_data(self) -> Tuple[
Tuple[List[NDArray[np.uint8]], List[int]],
Tuple[List[NDArray[np.uint8]], List[int]]
]:
"""
Load complete MNIST dataset (training and test sets).
Returns:
Tuple of ((x_train, y_train), (x_test, y_test)) where:
- x_train: 60,000 training images (28x28 uint8 arrays)
- y_train: 60,000 training labels (0-9)
- x_test: 10,000 test images (28x28 uint8 arrays)
- y_test: 10,000 test labels (0-9)
Example:
>>> loader = MnistDataloader(
... 'data/raw/train-images.idx3-ubyte',
... 'data/raw/train-labels.idx1-ubyte',
... 'data/raw/t10k-images.idx3-ubyte',
... 'data/raw/t10k-labels.idx1-ubyte'
... )
>>> (x_train, y_train), (x_test, y_test) = loader.load_data()
>>> print(f"Training: {len(x_train)} images")
Training: 60000 images
>>> print(f"Test: {len(x_test)} images")
Test: 10000 images
"""
x_train, y_train = self.read_images_labels(
self.training_images_filepath,
self.training_labels_filepath
)
x_test, y_test = self.read_images_labels(
self.test_images_filepath,
self.test_labels_filepath
)
return (x_train, y_train), (x_test, y_test)