"""
Dataset-aware augmentation for training
"""

import warnings

import cv2
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
from typing import Dict, Any, List, Optional


def _to_uint8(array: np.ndarray) -> np.ndarray:
    """Convert an array with values in [0, 1] to uint8 values in [0, 255].

    A plain ``(array * 255).astype(np.uint8)`` truncates toward zero, so a
    float value meant to map to level ``k`` that lands just below ``k``
    becomes ``k - 1``; it also does not saturate out-of-range values
    (NumPy's float -> uint8 cast of such values is platform-dependent, and an
    integer input already scaled to 255 would wrap around). Clipping and
    rounding avoid both problems.
    """
    scaled = np.asarray(array, dtype=np.float32) * 255.0
    return np.rint(np.clip(scaled, 0.0, 255.0)).astype(np.uint8)


class DatasetAwareAugmentation:
    """Dataset-aware augmentation pipeline.

    The albumentations pipeline is built once at construction time from
    ``config`` and applied by calling the instance. Random augmentations are
    only added when ``is_training`` is true and
    ``config.get('augmentation.enabled', True)`` is truthy; the final step is
    always ``ToTensorV2``.
    """

    def __init__(self, config, dataset_name: str, is_training: bool = True):
        """
        Initialize augmentation pipeline

        Args:
            config: Configuration object; read via ``config.get(key, default)``
                with dotted keys such as ``'augmentation.common'``
            dataset_name: Dataset name; ``'receipts'`` additionally enables the
                augmentations listed under ``'augmentation.receipts'``
            is_training: Whether in training mode
        """
        self.config = config
        self.dataset_name = dataset_name
        self.is_training = is_training

        # Build augmentation pipeline
        self.transform = self._build_transform()

    def _build_transform(self) -> A.Compose:
        """Build the albumentations transform pipeline.

        Order: common augmentations, then dataset-specific ones (training
        only), then ``ToTensorV2``.
        """
        transforms: List[Any] = []

        if self.is_training and self.config.get('augmentation.enabled', True):
            transforms.extend(self._build_common_transforms())

            # Dataset-specific augmentations are random too, so they are kept
            # under the same training/enabled guard as the common ones.
            if self.dataset_name == 'receipts':
                transforms.extend(self._build_receipt_transforms())

        # Always convert to tensor
        transforms.append(ToTensorV2())

        # 'mask' is a built-in Compose target, so it needs no
        # additional_targets entry.
        return A.Compose(transforms)

    def _build_common_transforms(self) -> List[Any]:
        """Build the dataset-independent augmentations from 'augmentation.common'.

        Each entry is a mapping with a ``'type'`` key and an optional
        ``'prob'`` (default 0.5). Unknown types are skipped with a warning.
        """
        transforms: List[Any] = []

        # NOTE(review): var_limit / quality_lower / quality_upper are
        # deprecated in newer albumentations releases — confirm against the
        # pinned albumentations version before upgrading.
        for aug_config in self.config.get('augmentation.common', []) or []:
            aug_type = aug_config.get('type')
            prob = aug_config.get('prob', 0.5)

            if aug_type == 'noise':
                transforms.append(A.GaussNoise(var_limit=(10.0, 50.0), p=prob))
            elif aug_type == 'motion_blur':
                transforms.append(A.MotionBlur(blur_limit=7, p=prob))
            elif aug_type == 'jpeg_compression':
                quality_range = aug_config.get('quality', [60, 95])
                transforms.append(
                    A.ImageCompression(
                        quality_lower=quality_range[0],
                        quality_upper=quality_range[1],
                        p=prob,
                    )
                )
            elif aug_type == 'lighting':
                # With probability `prob`, apply one of the lighting changes.
                transforms.append(
                    A.OneOf(
                        [
                            A.RandomBrightnessContrast(p=1.0),
                            A.RandomGamma(p=1.0),
                            A.HueSaturationValue(p=1.0),
                        ],
                        p=prob,
                    )
                )
            elif aug_type == 'perspective':
                transforms.append(A.Perspective(scale=(0.02, 0.05), p=prob))
            else:
                # A typo in the config would otherwise silently disable an
                # augmentation.
                warnings.warn(
                    f"Ignoring unknown augmentation type {aug_type!r} "
                    f"in 'augmentation.common'"
                )

        return transforms

    def _build_receipt_transforms(self) -> List[Any]:
        """Build receipt-specific augmentations from 'augmentation.receipts'.

        Each entry is a mapping with a ``'type'`` key and an optional
        ``'prob'`` (default 0.5). Unknown types are skipped with a warning.
        """
        transforms: List[Any] = []

        # NOTE(review): num_shadows_lower / num_shadows_upper are deprecated in
        # newer albumentations releases — confirm against the pinned version.
        for aug_config in self.config.get('augmentation.receipts', []) or []:
            aug_type = aug_config.get('type')
            prob = aug_config.get('prob', 0.5)

            if aug_type == 'stain':
                # Simulate stains using random shadows anywhere in the image
                transforms.append(
                    A.RandomShadow(
                        shadow_roi=(0, 0, 1, 1),
                        num_shadows_lower=1,
                        num_shadows_upper=3,
                        shadow_dimension=5,
                        p=prob,
                    )
                )
            elif aug_type == 'fold':
                # Simulate folds using grid distortion
                transforms.append(
                    A.GridDistortion(num_steps=5, distort_limit=0.1, p=prob)
                )
            else:
                warnings.warn(
                    f"Ignoring unknown augmentation type {aug_type!r} "
                    f"in 'augmentation.receipts'"
                )

        return transforms

    def __call__(self, image: np.ndarray, mask: Optional[np.ndarray] = None) -> Dict[str, Any]:
        """
        Apply augmentation

        Args:
            image: Input image (H, W, 3), float32, [0, 1]
            mask: Optional mask (H, W), uint8, {0, 1}

        Returns:
            Dictionary with 'image' (float tensor in [0, 1], channels-first as
            produced by ToTensorV2) and, when a mask was given, 'mask' (float
            tensor in [0, 1] with a leading channel dimension added)
        """
        # albumentations operates on uint8 images here; convert with rounding
        # and clipping so the [0, 1] -> [0, 255] round trip is not lossy.
        image_uint8 = _to_uint8(image)

        if mask is None:
            augmented = self.transform(image=image_uint8)
        else:
            augmented = self.transform(image=image_uint8, mask=_to_uint8(mask))
            # Back to [0, 1] and add a channel dim: (H, W) -> (1, H, W)
            augmented['mask'] = (augmented['mask'].float() / 255.0).unsqueeze(0)

        # Convert back to float in [0, 1]
        augmented['image'] = augmented['image'].float() / 255.0
        return augmented


def get_augmentation(config, dataset_name: str, is_training: bool = True) -> DatasetAwareAugmentation:
    """
    Get augmentation pipeline

    Args:
        config: Configuration object
        dataset_name: Dataset name
        is_training: Whether in training mode

    Returns:
        Augmentation pipeline
    """
    return DatasetAwareAugmentation(config, dataset_name, is_training)