mnist-digit-classifier / scripts /augmentation.py
faizan
fix: resolve all 468 ruff linting errors (code quality enforcement complete)
e77a25a
"""
Data Augmentation Module for MNIST
Provides augmentation transforms to improve model robustness:
- Random rotation (±15°): Simulates varied handwriting angles
- Random translation (±10%): Handles off-center digits
- Random scaling (90-110%): Accounts for different digit sizes
These augmentations are applied on-the-fly during training for infinite variations.
Usage:
from scripts.augmentation import get_train_augmentation, get_val_augmentation
from scripts.preprocessing import MnistDataset
# Training with augmentation
train_dataset = MnistDataset(x_train, y_train, transform=get_train_augmentation())
# Validation/test without augmentation
val_dataset = MnistDataset(x_val, y_val, transform=get_val_augmentation())
"""
from torchvision import transforms
import torch
def get_train_augmentation():
    """
    Build the standard augmentation pipeline used for training data.

    Two random transforms are chained, in this order:
    - Rotation: angle drawn from ±15°
    - Affine shift/zoom: up to ±10% translation on each axis and
      90-110% scaling, with no further rotation

    Both steps resample with bilinear interpolation and fill uncovered
    pixels with 0.

    Note: Normalization is expected to happen in MnistDataset, not here.

    Returns:
        torchvision.transforms.Compose: Composition of augmentation transforms
    """
    bilinear = transforms.InterpolationMode.BILINEAR

    # Step 1: random rotation within ±15 degrees; 0 fills with black (background).
    rotate = transforms.RandomRotation(
        degrees=15,
        interpolation=bilinear,
        fill=0,
    )

    # Step 2: random translation and zoom only. degrees=0 disables rotation
    # here because step 1 already handles it.
    shift_and_zoom = transforms.RandomAffine(
        degrees=0,
        translate=(0.1, 0.1),  # ±10% horizontal and vertical shift
        scale=(0.9, 1.1),  # 90-110% zoom
        interpolation=bilinear,
        fill=0,
    )

    return transforms.Compose([rotate, shift_and_zoom])
def get_val_augmentation():
    """
    Return the transform to use for validation/test data.

    Evaluation data is left unmodified, so there is no pipeline to build:
    the function returns None, meaning "apply no transform".

    Returns:
        None (no transforms)
    """
    return None
def get_mild_augmentation():
    """
    Build a conservative augmentation pipeline.

    Intended for cases where the standard settings are too strong:
    - Rotation: ±10°
    - Translation: ±5% on each axis
    - Scaling: 95-105%

    Both steps resample with bilinear interpolation and fill uncovered
    pixels with 0.

    Returns:
        torchvision.transforms.Compose: Mild augmentation transforms
    """
    bilinear = transforms.InterpolationMode.BILINEAR

    # Rotation is its own step; the affine step below keeps degrees=0.
    rotate = transforms.RandomRotation(
        degrees=10,
        interpolation=bilinear,
        fill=0,
    )
    shift_and_zoom = transforms.RandomAffine(
        degrees=0,
        translate=(0.05, 0.05),
        scale=(0.95, 1.05),
        interpolation=bilinear,
        fill=0,
    )

    return transforms.Compose([rotate, shift_and_zoom])
def get_aggressive_augmentation():
    """
    Build a strong augmentation pipeline.

    Use with caution - may distort digits beyond recognition:
    - Rotation: ±20°
    - Translation: ±15% on each axis
    - Scaling: 80-120%

    Elastic deformation is not part of the pipeline; see the note in the
    body if it is needed.

    Returns:
        torchvision.transforms.Compose: Aggressive augmentation transforms
    """
    bilinear = transforms.InterpolationMode.BILINEAR

    # Rotation is its own step; the affine step below keeps degrees=0.
    rotate = transforms.RandomRotation(
        degrees=20,
        interpolation=bilinear,
        fill=0,
    )
    shift_and_zoom = transforms.RandomAffine(
        degrees=0,
        translate=(0.15, 0.15),
        scale=(0.8, 1.2),
        interpolation=bilinear,
        fill=0,
    )

    # Note: Add elastic deformation if needed, e.g. by appending
    # transforms.ElasticTransform(alpha=34.0, sigma=4.0) to the list below.
    return transforms.Compose([rotate, shift_and_zoom])
def visualize_augmentations(image: torch.Tensor, transform, num_samples: int = 9) -> list:
    """
    Apply an augmentation repeatedly to one image to inspect its variations.

    Useful for debugging and understanding augmentation effects.

    Args:
        image: Single image tensor (1, 28, 28)
        transform: Callable augmentation to apply, or None to leave the
            image unchanged
        num_samples: Number of augmented versions to generate

    Returns:
        list: ``num_samples`` results of ``transform(image)``. When
        ``transform`` is None the list holds the same ``image`` object
        repeated (no copies are made).
    """
    # Compare against None explicitly: a valid transform may still be falsy
    # (e.g. an empty container-style module that defines __len__), and a
    # plain truthiness test would silently skip it.
    if transform is None:
        return [image] * num_samples
    return [transform(image) for _ in range(num_samples)]
# Preset registry: each name maps either to a zero-argument factory that
# builds the pipeline on demand, or to None when no augmentation applies.
AUGMENTATION_PRESETS = {
    'none': None,
    'mild': get_mild_augmentation,
    'standard': get_train_augmentation,
    'aggressive': get_aggressive_augmentation,
}
def get_augmentation_by_name(preset_name: str = 'standard'):
    """
    Look up an augmentation pipeline by its preset name.

    Args:
        preset_name: A key of AUGMENTATION_PRESETS, i.e. one of
            ['none', 'mild', 'standard', 'aggressive']

    Returns:
        Augmentation transform or None

    Raises:
        ValueError: If ``preset_name`` is not a known preset.
    """
    if preset_name in AUGMENTATION_PRESETS:
        entry = AUGMENTATION_PRESETS[preset_name]
        # Entries are factories that build a fresh pipeline per call;
        # non-callable entries (None) are returned as-is.
        if callable(entry):
            return entry()
        return entry

    valid_names = list(AUGMENTATION_PRESETS.keys())
    raise ValueError(
        f"Unknown preset '{preset_name}'. "
        f"Choose from: {valid_names}"
    )