Spaces:
Sleeping
Sleeping
| """ | |
| Data Augmentation Module for MNIST | |
| Provides augmentation transforms to improve model robustness: | |
| - Random rotation (±15°): Simulates varied handwriting angles | |
| - Random translation (±10%): Handles off-center digits | |
| - Random scaling (90-110%): Accounts for different digit sizes | |
| These augmentations are applied on-the-fly during training for infinite variations. | |
| Usage: | |
| from scripts.augmentation import get_train_augmentation, get_val_augmentation | |
| from scripts.preprocessing import MnistDataset | |
| # Training with augmentation | |
| train_dataset = MnistDataset(x_train, y_train, transform=get_train_augmentation()) | |
| # Validation/test without augmentation | |
| val_dataset = MnistDataset(x_val, y_val, transform=get_val_augmentation()) | |
| """ | |
| from torchvision import transforms | |
| import torch | |
def get_train_augmentation():
    """
    Build the standard augmentation pipeline for training data.

    Each call of the returned pipeline samples one random affine warp that
    combines realistic, readability-preserving perturbations:
    - Rotation: uniform in ±15° (handwriting angle variation)
    - Translation: up to ±10% of image width/height (off-center digits)
    - Scaling: uniform in 90-110% (size variation)

    Rotation, translation and scaling are applied as a single RandomAffine so
    the image is resampled (bilinear) only once. Chaining a separate
    RandomRotation and a zero-degree RandomAffine yields the same family of
    geometric transforms, but interpolates twice. That compounds blur and
    crops the rotated corners before the second warp.

    Note: Normalization happens in MnistDataset, not here.

    Returns:
        torchvision.transforms.Compose: Composition of augmentation transforms
    """
    return transforms.Compose([
        transforms.RandomAffine(
            degrees=15,            # ±15° rotation
            translate=(0.1, 0.1),  # ±10% horizontal and vertical shift
            scale=(0.9, 1.1),      # 90-110% zoom
            interpolation=transforms.InterpolationMode.BILINEAR,
            # Pixels exposed by the warp are filled with 0 (black background).
            # NOTE(review): assumes background pixels are 0 when this runs,
            # i.e. before normalization — confirm the order in MnistDataset.
            fill=0,
        ),
    ])
def get_val_augmentation() -> None:
    """
    Get the augmentation pipeline for validation/test data.

    No augmentation is applied. This returns ``None`` rather than an
    identity transform object, so evaluation runs on the original,
    unmodified data.

    Consumers are expected to skip the transform step when given ``None``.
    NOTE(review): assumed for MnistDataset — confirm it checks for None
    before calling the transform.

    Returns:
        None (no transforms)
    """
    return None
def get_mild_augmentation():
    """
    Get a milder augmentation pipeline (conservative settings).

    Use this if the standard augmentation is too aggressive:
    - Rotation: ±10° (reduced from ±15°)
    - Translation: ±5% (reduced from ±10%)
    - Scaling: 95-105% (reduced from 90-110%)

    All three perturbations are applied as one RandomAffine warp, so the image
    is interpolated only once. Separate rotate-then-affine steps would
    resample twice and blur the digit.

    Returns:
        torchvision.transforms.Compose: Mild augmentation transforms
    """
    return transforms.Compose([
        transforms.RandomAffine(
            degrees=10,              # ±10° rotation
            translate=(0.05, 0.05),  # ±5% horizontal and vertical shift
            scale=(0.95, 1.05),      # 95-105% zoom
            interpolation=transforms.InterpolationMode.BILINEAR,
            # NOTE(review): fill=0 assumes background pixels are 0 at this
            # stage (before normalization) — confirm in MnistDataset.
            fill=0,
        ),
    ])
def get_aggressive_augmentation():
    """
    Get an aggressive augmentation pipeline (stronger settings).

    Use with caution - may distort digits beyond recognition:
    - Rotation: ±20°
    - Translation: ±15%
    - Scaling: 80-120%
    - Elastic deformation: not enabled (see the note in the body)

    All geometric perturbations are applied as one RandomAffine warp, so the
    image is interpolated only once instead of twice.

    Returns:
        torchvision.transforms.Compose: Aggressive augmentation transforms
    """
    return transforms.Compose([
        transforms.RandomAffine(
            degrees=20,              # ±20° rotation
            translate=(0.15, 0.15),  # ±15% horizontal and vertical shift
            scale=(0.8, 1.2),        # 80-120% zoom
            interpolation=transforms.InterpolationMode.BILINEAR,
            # NOTE(review): fill=0 assumes background pixels are 0 at this
            # stage (before normalization) — confirm in MnistDataset.
            fill=0,
        ),
        # If even stronger deformation is needed, an elastic warp such as
        # transforms.ElasticTransform (e.g. alpha=34.0, sigma=4.0) could be
        # appended here.
    ])
def visualize_augmentations(image: torch.Tensor, transform, num_samples: int = 9):
    """
    Apply a transform repeatedly to one image to inspect its variation.

    Useful for debugging and for understanding how strong an augmentation
    pipeline is.

    Args:
        image: Single image tensor, e.g. of shape (1, 28, 28) for MNIST.
        transform: Callable applied to ``image`` once per sample, or ``None``
            to return the image unmodified (as produced by the validation
            pipeline).
        num_samples: Number of versions to generate. A value of 0 or less
            yields an empty list.

    Returns:
        list: ``num_samples`` tensors. When ``transform`` is ``None``, every
        entry is the very same ``image`` object (not a copy).
    """
    # Compare against None explicitly. A valid transform object may still be
    # falsy (e.g. it defines __len__ returning 0), and a plain truthiness test
    # would then silently skip it.
    if transform is None:
        return [image] * num_samples
    return [transform(image) for _ in range(num_samples)]
# Registry of named augmentation presets.
# Each value is either a zero-argument factory that builds a fresh pipeline,
# or None meaning "no augmentation". get_augmentation_by_name resolves both.
AUGMENTATION_PRESETS = dict(
    none=None,
    mild=get_mild_augmentation,
    standard=get_train_augmentation,
    aggressive=get_aggressive_augmentation,
)
def get_augmentation_by_name(preset_name: str = 'standard'):
    """
    Resolve a preset name from AUGMENTATION_PRESETS into a transform.

    Args:
        preset_name: One of ['none', 'mild', 'standard', 'aggressive'].

    Returns:
        The pipeline built by the preset's factory, or None for presets that
        map to None (i.e. 'none').

    Raises:
        ValueError: If ``preset_name`` is not a registered preset.
    """
    try:
        entry = AUGMENTATION_PRESETS[preset_name]
    except KeyError:
        known = list(AUGMENTATION_PRESETS.keys())
        raise ValueError(
            f"Unknown preset '{preset_name}'. "
            f"Choose from: {known}"
        ) from None
    # Registry values are either factories (call them) or plain None.
    if callable(entry):
        return entry()
    return entry