Spaces:
Sleeping
Sleeping
| """ | |
| MNIST Preprocessing Pipeline | |
| This module provides PyTorch Dataset and DataLoader setup for MNIST: | |
| - Normalization: Convert uint8 [0, 255] to float32 [0, 1] | |
| - Tensor conversion: numpy arrays to PyTorch tensors | |
| - Channel dimension: (28, 28) -> (1, 28, 28) for CNN input | |
| - Optional transforms for augmentation | |
| Usage: | |
| from scripts.preprocessing import MnistDataset, create_dataloaders | |
| train_dataset = MnistDataset(x_train, y_train, transform=None) | |
| train_loader, val_loader = create_dataloaders( | |
| train_dataset, val_dataset, batch_size=64 | |
| ) | |
| """ | |
| from typing import Optional, Tuple, List | |
| import numpy as np | |
| from numpy.typing import NDArray | |
| import torch | |
| from torch.utils.data import Dataset, DataLoader | |
class MnistDataset(Dataset):
    """
    PyTorch Dataset for MNIST images.

    Handles normalization and conversion to tensors suitable for CNN
    training: uint8 arrays in [0, 255] become float32 tensors in [0, 1]
    with an explicit channel dimension, shape (1, 28, 28).
    """

    def __init__(
        self,
        images: List[NDArray[np.uint8]],
        labels: List[int],
        transform: Optional[torch.nn.Module] = None
    ):
        """
        Initialize MNIST dataset.

        Args:
            images: List of 28x28 numpy arrays with pixel values [0, 255]
            labels: List of integer labels (0-9)
            transform: Optional torchvision transforms for augmentation

        Raises:
            ValueError: If the number of images and labels differ.
        """
        # Validate with an explicit exception: `assert` is stripped when
        # Python runs with -O, which would silently skip this check.
        if len(images) != len(labels):
            raise ValueError(
                f"Mismatch: {len(images)} images but {len(labels)} labels"
            )
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self) -> int:
        """Return number of samples in dataset."""
        return len(self.images)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get a single sample.

        Args:
            idx: Index of sample to retrieve

        Returns:
            Tuple of (image_tensor, label_tensor)
            - image_tensor: Shape (1, 28, 28), dtype float32, range [0, 1]
            - label_tensor: Shape (), dtype long, value in [0, 9]
        """
        # Normalize to [0, 1]; asarray avoids an extra copy when the
        # stored image is already a numpy array.
        image = np.asarray(self.images[idx]).astype(np.float32) / 255.0
        # from_numpy shares the buffer (no second copy); add the channel
        # dimension for CNN input: (28, 28) -> (1, 28, 28).
        image_tensor = torch.from_numpy(image).unsqueeze(0)
        label_tensor = torch.tensor(self.labels[idx], dtype=torch.long)
        # Apply transforms if provided (e.g., augmentation). Explicit None
        # check so a transform defining __bool__/__len__ is not skipped.
        if self.transform is not None:
            image_tensor = self.transform(image_tensor)
        return image_tensor, label_tensor
def create_dataloaders(
    train_dataset: Dataset,
    val_dataset: Dataset,
    batch_size: int = 64,
    num_workers: int = 2,
    shuffle_train: bool = True
) -> Tuple[DataLoader, DataLoader]:
    """
    Build the training and validation DataLoader pair.

    Args:
        train_dataset: Training dataset
        val_dataset: Validation dataset
        batch_size: Number of samples per batch
        num_workers: Number of worker processes for data loading
        shuffle_train: Whether to shuffle training data

    Returns:
        Tuple of (train_loader, val_loader)
    """
    # Both loaders share batching/worker settings; only shuffling differs.
    # pin_memory speeds up host-to-GPU transfers.
    shared = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
    )
    train_loader = DataLoader(train_dataset, shuffle=shuffle_train, **shared)
    # Validation order is irrelevant for evaluation, so never shuffle.
    val_loader = DataLoader(val_dataset, shuffle=False, **shared)
    return train_loader, val_loader
def create_test_dataloader(
    test_dataset: Dataset,
    batch_size: int = 64,
    num_workers: int = 2
) -> DataLoader:
    """
    Build a DataLoader over the held-out test set.

    Args:
        test_dataset: Test dataset
        batch_size: Number of samples per batch
        num_workers: Number of worker processes for data loading

    Returns:
        Test DataLoader
    """
    # Test data is never shuffled so metrics and predictions line up
    # with the dataset order; pin_memory speeds up GPU transfer.
    return DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )
def split_train_val(
    images: List[NDArray[np.uint8]],
    labels: List[int],
    val_split: float = 0.15,
    random_seed: int = 42
) -> Tuple[
    Tuple[List[NDArray[np.uint8]], List[int]],
    Tuple[List[NDArray[np.uint8]], List[int]]
]:
    """
    Split training data into train and validation sets.

    Uses stratified sampling to maintain class balance.

    Args:
        images: List of training images
        labels: List of training labels
        val_split: Fraction of data to use for validation (0.15 = 15%)
        random_seed: Random seed for reproducibility

    Returns:
        Tuple of ((train_images, train_labels), (val_images, val_labels))

    Raises:
        ValueError: If val_split is not in [0, 1).
    """
    from collections import defaultdict

    if not 0.0 <= val_split < 1.0:
        raise ValueError(f"val_split must be in [0, 1), got {val_split}")

    # Group indices by class for stratified split
    class_indices = defaultdict(list)
    for idx, label in enumerate(labels):
        class_indices[label].append(idx)

    # Use a local Generator for reproducibility: np.random.seed() would
    # mutate the *global* RNG state as a side effect on unrelated code.
    rng = np.random.default_rng(random_seed)

    train_indices: List[int] = []
    val_indices: List[int] = []

    # Split each class separately so both sides keep the class balance.
    for indices in class_indices.values():
        shuffled = rng.permutation(indices)
        split_point = int(len(shuffled) * (1 - val_split))
        train_indices.extend(shuffled[:split_point].tolist())
        val_indices.extend(shuffled[split_point:].tolist())

    # Shuffle combined indices so classes are interleaved, not blocked.
    rng.shuffle(train_indices)
    rng.shuffle(val_indices)

    # Extract images and labels
    train_images = [images[i] for i in train_indices]
    train_labels = [labels[i] for i in train_indices]
    val_images = [images[i] for i in val_indices]
    val_labels = [labels[i] for i in val_indices]
    return (train_images, train_labels), (val_images, val_labels)
def get_dataset_statistics(dataset: MnistDataset) -> dict:
    """
    Compute summary statistics for a dataset (useful for debugging).

    Args:
        dataset: MnistDataset instance

    Returns:
        Dictionary with statistics
    """
    from collections import Counter

    # Probe one sample to report the preprocessing output format.
    first_image, first_label = dataset[0]

    # Tally label frequencies across the whole dataset.
    label_counts = Counter(
        dataset[i][1].item() for i in range(len(dataset))
    )

    return {
        'num_samples': len(dataset),
        'sample_image_shape': tuple(first_image.shape),
        'sample_image_dtype': str(first_image.dtype),
        'sample_image_range': (first_image.min().item(), first_image.max().item()),
        'sample_label_dtype': str(first_label.dtype),
        'class_distribution': dict(sorted(label_counts.items())),
    }