# mnist-digit-classifier/scripts/preprocessing.py
"""
MNIST Preprocessing Pipeline
This module provides PyTorch Dataset and DataLoader setup for MNIST:
- Normalization: Convert uint8 [0, 255] to float32 [0, 1]
- Tensor conversion: numpy arrays to PyTorch tensors
- Channel dimension: (28, 28) -> (1, 28, 28) for CNN input
- Optional transforms for augmentation
Usage:
from scripts.preprocessing import MnistDataset, create_dataloaders
train_dataset = MnistDataset(x_train, y_train, transform=None)
train_loader, val_loader = create_dataloaders(
train_dataset, val_dataset, batch_size=64
)
"""
from typing import Optional, Tuple, List
import numpy as np
from numpy.typing import NDArray
import torch
from torch.utils.data import Dataset, DataLoader


class MnistDataset(Dataset):
    """
    PyTorch Dataset for MNIST images.

    Handles normalization and conversion to tensors suitable for CNN training.
    """

    def __init__(
        self,
        images: List[NDArray[np.uint8]],
        labels: List[int],
        transform: Optional[torch.nn.Module] = None
    ):
        """
        Initialize MNIST dataset.

        Args:
            images: List of 28x28 numpy arrays with pixel values [0, 255]
            labels: List of integer labels (0-9)
            transform: Optional torchvision transforms for augmentation
        """
        self.images = images
        self.labels = labels
        self.transform = transform

        # Validate inputs
        assert len(images) == len(labels), \
            f"Mismatch: {len(images)} images but {len(labels)} labels"

    def __len__(self) -> int:
        """Return number of samples in dataset."""
        return len(self.images)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get a single sample.

        Args:
            idx: Index of sample to retrieve

        Returns:
            Tuple of (image_tensor, label_tensor)
            - image_tensor: shape (1, 28, 28), dtype float32, range [0, 1]
            - label_tensor: shape (), dtype long, value in [0, 9]
        """
        # Get image and label
        image = np.array(self.images[idx])
        label = self.labels[idx]

        # Normalize to [0, 1]
        image = image.astype(np.float32) / 255.0

        # Convert to tensor and add channel dimension: (28, 28) -> (1, 28, 28)
        image = torch.tensor(image, dtype=torch.float32).unsqueeze(0)
        label = torch.tensor(label, dtype=torch.long)

        # Apply transforms if provided (e.g., augmentation)
        if self.transform:
            image = self.transform(image)

        return image, label
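

# Example (illustrative only): constructing MnistDataset from in-memory arrays and
# inspecting one preprocessed sample. The random arrays below are placeholders for
# real MNIST images; nothing here runs at import time.
def _example_mnist_dataset() -> None:
    """Minimal sketch of MnistDataset usage with synthetic data."""
    rng = np.random.default_rng(0)
    fake_images = [rng.integers(0, 256, size=(28, 28), dtype=np.uint8) for _ in range(4)]
    fake_labels = [0, 1, 2, 3]

    dataset = MnistDataset(fake_images, fake_labels, transform=None)
    image, label = dataset[0]

    # Preprocessed sample: float32 in [0, 1] with an explicit channel dimension
    print(image.shape)   # torch.Size([1, 28, 28])
    print(image.dtype)   # torch.float32
    print(label.item())  # 0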


def create_dataloaders(
    train_dataset: Dataset,
    val_dataset: Dataset,
    batch_size: int = 64,
    num_workers: int = 2,
    shuffle_train: bool = True
) -> Tuple[DataLoader, DataLoader]:
    """
    Create DataLoader instances for training and validation.

    Args:
        train_dataset: Training dataset
        val_dataset: Validation dataset
        batch_size: Number of samples per batch
        num_workers: Number of worker processes for data loading
        shuffle_train: Whether to shuffle training data

    Returns:
        Tuple of (train_loader, val_loader)
    """
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=shuffle_train,
        num_workers=num_workers,
        pin_memory=True  # Faster host-to-GPU transfer
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,  # No need to shuffle validation data
        num_workers=num_workers,
        pin_memory=True
    )
    return train_loader, val_loader
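

# Example (illustrative only): wiring MnistDataset instances into DataLoaders.
# Synthetic arrays stand in for the real train/validation splits here.
def _example_create_dataloaders() -> None:
    """Minimal sketch of create_dataloaders usage."""
    rng = np.random.default_rng(0)
    x_train = [rng.integers(0, 256, size=(28, 28), dtype=np.uint8) for _ in range(64)]
    y_train = [i % 10 for i in range(64)]
    x_val = [rng.integers(0, 256, size=(28, 28), dtype=np.uint8) for _ in range(16)]
    y_val = [i % 10 for i in range(16)]

    train_loader, val_loader = create_dataloaders(
        MnistDataset(x_train, y_train),
        MnistDataset(x_val, y_val),
        batch_size=16,
        num_workers=0,  # single-process loading keeps the example portable
    )

    # Each batch is an (images, labels) pair: images (B, 1, 28, 28), labels (B,)
    images, labels = next(iter(train_loader))
    val_images, val_labels = next(iter(val_loader))
    print(images.shape, labels.shape)          # torch.Size([16, 1, 28, 28]) torch.Size([16])
    print(val_images.shape, val_labels.shape)  # torch.Size([16, 1, 28, 28]) torch.Size([16])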


def create_test_dataloader(
    test_dataset: Dataset,
    batch_size: int = 64,
    num_workers: int = 2
) -> DataLoader:
    """
    Create DataLoader for the test set.

    Args:
        test_dataset: Test dataset
        batch_size: Number of samples per batch
        num_workers: Number of worker processes for data loading

    Returns:
        Test DataLoader
    """
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,  # Never shuffle test data
        num_workers=num_workers,
        pin_memory=True
    )
    return test_loader


def split_train_val(
    images: List[NDArray[np.uint8]],
    labels: List[int],
    val_split: float = 0.15,
    random_seed: int = 42
) -> Tuple[
    Tuple[List[NDArray[np.uint8]], List[int]],
    Tuple[List[NDArray[np.uint8]], List[int]]
]:
    """
    Split training data into train and validation sets.

    Uses stratified sampling to maintain class balance.

    Args:
        images: List of training images
        labels: List of training labels
        val_split: Fraction of data to use for validation (0.15 = 15%)
        random_seed: Random seed for reproducibility

    Returns:
        Tuple of ((train_images, train_labels), (val_images, val_labels))
    """
    from collections import defaultdict

    # Group indices by class for the stratified split
    class_indices = defaultdict(list)
    for idx, label in enumerate(labels):
        class_indices[label].append(idx)

    # Set random seed for reproducibility
    np.random.seed(random_seed)

    train_indices = []
    val_indices = []

    # Split each class separately so both sets keep the original class balance
    for indices in class_indices.values():
        indices = np.array(indices)
        np.random.shuffle(indices)
        split_point = int(len(indices) * (1 - val_split))
        train_indices.extend(indices[:split_point])
        val_indices.extend(indices[split_point:])

    # Shuffle the combined indices so classes are interleaved
    np.random.shuffle(train_indices)
    np.random.shuffle(val_indices)

    # Extract images and labels
    train_images = [images[i] for i in train_indices]
    train_labels = [labels[i] for i in train_indices]
    val_images = [images[i] for i in val_indices]
    val_labels = [labels[i] for i in val_indices]

    return (train_images, train_labels), (val_images, val_labels)
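

# Example (illustrative only): a stratified 85/15 split on synthetic data.
# The tiny arrays below stand in for the real MNIST training set.
def _example_split_train_val() -> None:
    """Minimal sketch of split_train_val usage."""
    rng = np.random.default_rng(0)
    images = [rng.integers(0, 256, size=(28, 28), dtype=np.uint8) for _ in range(100)]
    labels = [i % 10 for i in range(100)]  # 10 samples per class

    (train_images, train_labels), (val_images, val_labels) = split_train_val(
        images, labels, val_split=0.15, random_seed=42
    )

    # With 10 samples per class and val_split=0.15, each class contributes
    # int(10 * 0.85) = 8 samples to train and 2 to validation.
    print(len(train_images), len(val_images))                     # 80 20
    print(sorted(set(train_labels)) == sorted(set(val_labels)))   # True: all 10 classes in both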


def get_dataset_statistics(dataset: MnistDataset) -> dict:
    """
    Compute statistics for a dataset (useful for debugging).

    Args:
        dataset: MnistDataset instance

    Returns:
        Dictionary with statistics
    """
    # Sample the first image to check preprocessing
    sample_img, sample_label = dataset[0]

    # Count label occurrences across the full dataset
    from collections import Counter
    label_counts = Counter([dataset[i][1].item() for i in range(len(dataset))])

    return {
        'num_samples': len(dataset),
        'sample_image_shape': tuple(sample_img.shape),
        'sample_image_dtype': str(sample_img.dtype),
        'sample_image_range': (sample_img.min().item(), sample_img.max().item()),
        'sample_label_dtype': str(sample_label.dtype),
        'class_distribution': dict(sorted(label_counts.items()))
    }
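

# Smoke-test sketch (illustrative only): ties the helpers together on synthetic
# data so the module can be run directly. Real usage would pass actual MNIST
# arrays instead of the random images generated here.
if __name__ == "__main__":
    rng = np.random.default_rng(42)
    all_images = [rng.integers(0, 256, size=(28, 28), dtype=np.uint8) for _ in range(200)]
    all_labels = [i % 10 for i in range(200)]

    # Stratified split, then wrap each subset in a Dataset
    (train_x, train_y), (val_x, val_y) = split_train_val(all_images, all_labels)
    train_ds = MnistDataset(train_x, train_y)
    val_ds = MnistDataset(val_x, val_y)

    # num_workers=0 keeps the demo single-process and portable
    train_loader, val_loader = create_dataloaders(
        train_ds, val_ds, batch_size=32, num_workers=0
    )

    batch_images, batch_labels = next(iter(train_loader))
    print("Batch shapes:", batch_images.shape, batch_labels.shape)
    print("Val batches:", len(val_loader))
    print("Train stats:", get_dataset_statistics(train_ds))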