""" Data Augmentation Module for MNIST Provides augmentation transforms to improve model robustness: - Random rotation (±15°): Simulates varied handwriting angles - Random translation (±10%): Handles off-center digits - Random scaling (90-110%): Accounts for different digit sizes These augmentations are applied on-the-fly during training for infinite variations. Usage: from scripts.augmentation import get_train_augmentation, get_val_augmentation from scripts.preprocessing import MnistDataset # Training with augmentation train_dataset = MnistDataset(x_train, y_train, transform=get_train_augmentation()) # Validation/test without augmentation val_dataset = MnistDataset(x_val, y_val, transform=get_val_augmentation()) """ from torchvision import transforms import torch def get_train_augmentation(): """ Get augmentation pipeline for training data. Applies realistic transformations that preserve digit readability: - Rotation: ±15° (typical handwriting angle variation) - Translation: ±10% (off-center digits) - Scaling: 90-110% (size variation) Note: Normalization happens in MnistDataset, not here. Returns: torchvision.transforms.Compose: Composition of augmentation transforms """ return transforms.Compose([ # Random rotation within ±15 degrees transforms.RandomRotation( degrees=15, interpolation=transforms.InterpolationMode.BILINEAR, fill=0 # Fill with black (background) ), # Random translation and scaling (no additional rotation) transforms.RandomAffine( degrees=0, # No rotation here (already done above) translate=(0.1, 0.1), # ±10% horizontal and vertical shift scale=(0.9, 1.1), # 90-110% zoom interpolation=transforms.InterpolationMode.BILINEAR, fill=0 # Fill with black ), ]) def get_val_augmentation(): """ Get augmentation pipeline for validation/test data. No augmentation is applied - returns identity transform. This ensures fair evaluation on original unmodified data. Returns: None (no transforms) """ return None def get_mild_augmentation(): """ Get milder augmentation pipeline (conservative settings). Use this if standard augmentation is too aggressive: - Rotation: ±10° (reduced from ±15°) - Translation: ±5% (reduced from ±10%) - Scaling: 95-105% (reduced from 90-110%) Returns: torchvision.transforms.Compose: Mild augmentation transforms """ return transforms.Compose([ transforms.RandomRotation( degrees=10, interpolation=transforms.InterpolationMode.BILINEAR, fill=0 ), transforms.RandomAffine( degrees=0, translate=(0.05, 0.05), scale=(0.95, 1.05), interpolation=transforms.InterpolationMode.BILINEAR, fill=0 ), ]) def get_aggressive_augmentation(): """ Get aggressive augmentation pipeline (stronger settings). Use with caution - may distort digits beyond recognition: - Rotation: ±20° - Translation: ±15% - Scaling: 80-120% - Elastic deformation (optional, commented out) Returns: torchvision.transforms.Compose: Aggressive augmentation transforms """ return transforms.Compose([ transforms.RandomRotation( degrees=20, interpolation=transforms.InterpolationMode.BILINEAR, fill=0 ), transforms.RandomAffine( degrees=0, translate=(0.15, 0.15), scale=(0.8, 1.2), interpolation=transforms.InterpolationMode.BILINEAR, fill=0 ), # Note: Add elastic deformation if needed # transforms.ElasticTransform(alpha=34.0, sigma=4.0) ]) def visualize_augmentations(image: torch.Tensor, transform, num_samples: int = 9): """ Apply augmentation multiple times to visualize variations. Useful for debugging and understanding augmentation effects. Args: image: Single image tensor (1, 28, 28) transform: Augmentation transform to apply num_samples: Number of augmented versions to generate Returns: list: List of augmented image tensors """ augmented_images = [] for _ in range(num_samples): if transform: aug_img = transform(image) else: aug_img = image augmented_images.append(aug_img) return augmented_images # Augmentation configuration presets AUGMENTATION_PRESETS = { 'none': None, 'mild': get_mild_augmentation, 'standard': get_train_augmentation, 'aggressive': get_aggressive_augmentation } def get_augmentation_by_name(preset_name: str = 'standard'): """ Get augmentation pipeline by preset name. Args: preset_name: One of ['none', 'mild', 'standard', 'aggressive'] Returns: Augmentation transform or None """ if preset_name not in AUGMENTATION_PRESETS: raise ValueError( f"Unknown preset '{preset_name}'. " f"Choose from: {list(AUGMENTATION_PRESETS.keys())}" ) preset = AUGMENTATION_PRESETS[preset_name] return preset() if callable(preset) else preset