""" MNIST Preprocessing Pipeline This module provides PyTorch Dataset and DataLoader setup for MNIST: - Normalization: Convert uint8 [0, 255] to float32 [0, 1] - Tensor conversion: numpy arrays to PyTorch tensors - Channel dimension: (28, 28) -> (1, 28, 28) for CNN input - Optional transforms for augmentation Usage: from scripts.preprocessing import MnistDataset, create_dataloaders train_dataset = MnistDataset(x_train, y_train, transform=None) train_loader, val_loader = create_dataloaders( train_dataset, val_dataset, batch_size=64 ) """ from typing import Optional, Tuple, List import numpy as np from numpy.typing import NDArray import torch from torch.utils.data import Dataset, DataLoader class MnistDataset(Dataset): """ PyTorch Dataset for MNIST images. Handles normalization and conversion to tensors suitable for CNN training. """ def __init__( self, images: List[NDArray[np.uint8]], labels: List[int], transform: Optional[torch.nn.Module] = None ): """ Initialize MNIST dataset. Args: images: List of 28x28 numpy arrays with pixel values [0, 255] labels: List of integer labels (0-9) transform: Optional torchvision transforms for augmentation """ self.images = images self.labels = labels self.transform = transform # Validate inputs assert len(images) == len(labels), \ f"Mismatch: {len(images)} images but {len(labels)} labels" def __len__(self) -> int: """Return number of samples in dataset.""" return len(self.images) def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]: """ Get a single sample. Args: idx: Index of sample to retrieve Returns: Tuple of (image_tensor, label_tensor) - image_tensor: Shape (1, 28, 28), dtype float32, range [0, 1] - label_tensor: Shape (), dtype long, value in [0, 9] """ # Get image and label image = np.array(self.images[idx]) label = self.labels[idx] # Normalize to [0, 1] image = image.astype(np.float32) / 255.0 # Convert to tensor and add channel dimension: (28, 28) -> (1, 28, 28) image = torch.tensor(image, dtype=torch.float32).unsqueeze(0) label = torch.tensor(label, dtype=torch.long) # Apply transforms if provided (e.g., augmentation) if self.transform: image = self.transform(image) return image, label def create_dataloaders( train_dataset: Dataset, val_dataset: Dataset, batch_size: int = 64, num_workers: int = 2, shuffle_train: bool = True ) -> Tuple[DataLoader, DataLoader]: """ Create DataLoader instances for training and validation. Args: train_dataset: Training dataset val_dataset: Validation dataset batch_size: Number of samples per batch num_workers: Number of worker processes for data loading shuffle_train: Whether to shuffle training data Returns: Tuple of (train_loader, val_loader) """ train_loader = DataLoader( train_dataset, batch_size=batch_size, shuffle=shuffle_train, num_workers=num_workers, pin_memory=True # Faster GPU transfer ) val_loader = DataLoader( val_dataset, batch_size=batch_size, shuffle=False, # No need to shuffle validation num_workers=num_workers, pin_memory=True ) return train_loader, val_loader def create_test_dataloader( test_dataset: Dataset, batch_size: int = 64, num_workers: int = 2 ) -> DataLoader: """ Create DataLoader for test set. Args: test_dataset: Test dataset batch_size: Number of samples per batch num_workers: Number of worker processes for data loading Returns: Test DataLoader """ test_loader = DataLoader( test_dataset, batch_size=batch_size, shuffle=False, # Never shuffle test data num_workers=num_workers, pin_memory=True ) return test_loader def split_train_val( images: List[NDArray[np.uint8]], labels: List[int], val_split: float = 0.15, random_seed: int = 42 ) -> Tuple[ Tuple[List[NDArray[np.uint8]], List[int]], Tuple[List[NDArray[np.uint8]], List[int]] ]: """ Split training data into train and validation sets. Uses stratified sampling to maintain class balance. Args: images: List of training images labels: List of training labels val_split: Fraction of data to use for validation (0.15 = 15%) random_seed: Random seed for reproducibility Returns: Tuple of ((train_images, train_labels), (val_images, val_labels)) """ from collections import defaultdict # Group indices by class for stratified split class_indices = defaultdict(list) for idx, label in enumerate(labels): class_indices[label].append(idx) # Set random seed np.random.seed(random_seed) train_indices = [] val_indices = [] # Split each class separately for class_label, indices in class_indices.items(): indices = np.array(indices) np.random.shuffle(indices) split_point = int(len(indices) * (1 - val_split)) train_indices.extend(indices[:split_point]) val_indices.extend(indices[split_point:]) # Shuffle combined indices np.random.shuffle(train_indices) np.random.shuffle(val_indices) # Extract images and labels train_images = [images[i] for i in train_indices] train_labels = [labels[i] for i in train_indices] val_images = [images[i] for i in val_indices] val_labels = [labels[i] for i in val_indices] return (train_images, train_labels), (val_images, val_labels) def get_dataset_statistics(dataset: MnistDataset) -> dict: """ Compute statistics for a dataset (useful for debugging). Args: dataset: MnistDataset instance Returns: Dictionary with statistics """ # Sample first image to check preprocessing sample_img, sample_label = dataset[0] # Count labels from collections import Counter label_counts = Counter([dataset[i][1].item() for i in range(len(dataset))]) return { 'num_samples': len(dataset), 'sample_image_shape': tuple(sample_img.shape), 'sample_image_dtype': str(sample_img.dtype), 'sample_image_range': (sample_img.min().item(), sample_img.max().item()), 'sample_label_dtype': str(sample_label.dtype), 'class_distribution': dict(sorted(label_counts.items())) }