# Document_Forgery_Detection / src / data / augmentation.py
# (Hugging Face viewer metadata, commented out so the module is importable:
#  uploader JKrishnanandhaa, "Upload 54 files", commit ff0e79e verified, 5.52 kB)
"""
Dataset-aware augmentation for training
"""
import cv2
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
from typing import Dict, Any, Optional
class DatasetAwareAugmentation:
    """Dataset-aware augmentation pipeline.

    Builds an albumentations pipeline from config: a set of common
    augmentations plus dataset-specific extras (currently only for
    'receipts'), followed by tensor conversion. Augmentations are applied
    only in training mode with 'augmentation.enabled' truthy; otherwise the
    pipeline is tensor conversion alone.
    """

    def __init__(self, config, dataset_name: str, is_training: bool = True):
        """
        Initialize augmentation pipeline

        Args:
            config: Configuration object exposing ``get(key, default)``
                with dotted keys (e.g. 'augmentation.common')
            dataset_name: Dataset name ('receipts' enables extra transforms)
            is_training: Whether in training mode
        """
        self.config = config
        self.dataset_name = dataset_name
        self.is_training = is_training
        # Build the pipeline once; it is reused for every __call__
        self.transform = self._build_transform()

    def _build_transform(self) -> A.Compose:
        """Build albumentations transform pipeline."""
        transforms = []
        if self.is_training and self.config.get('augmentation.enabled', True):
            transforms.extend(self._common_transforms())
            if self.dataset_name == 'receipts':
                transforms.extend(self._receipt_transforms())
        # Always convert to tensor (HWC numpy -> CHW torch tensor)
        transforms.append(ToTensorV2())
        # NOTE: 'mask' is a built-in target of A.Compose, so it must NOT be
        # declared in additional_targets — doing so is redundant in old
        # albumentations versions and raises ValueError in recent ones.
        return A.Compose(transforms)

    def _common_transforms(self) -> list:
        """Map 'augmentation.common' config entries to transforms.

        Each entry is a dict with a 'type' key and optional 'prob'
        (default 0.5); unknown types are silently skipped.
        """
        out = []
        for aug_config in self.config.get('augmentation.common', []):
            aug_type = aug_config.get('type')
            prob = aug_config.get('prob', 0.5)
            if aug_type == 'noise':
                out.append(A.GaussNoise(var_limit=(10.0, 50.0), p=prob))
            elif aug_type == 'motion_blur':
                out.append(A.MotionBlur(blur_limit=7, p=prob))
            elif aug_type == 'jpeg_compression':
                quality_range = aug_config.get('quality', [60, 95])
                # NOTE(review): quality_lower/quality_upper were replaced by
                # quality_range in albumentations >= 1.4 — confirm the pinned
                # version still accepts these keywords.
                out.append(
                    A.ImageCompression(quality_lower=quality_range[0],
                                       quality_upper=quality_range[1],
                                       p=prob)
                )
            elif aug_type == 'lighting':
                # Exactly one of the three photometric jitters per application
                out.append(
                    A.OneOf([
                        A.RandomBrightnessContrast(p=1.0),
                        A.RandomGamma(p=1.0),
                        A.HueSaturationValue(p=1.0),
                    ], p=prob)
                )
            elif aug_type == 'perspective':
                out.append(A.Perspective(scale=(0.02, 0.05), p=prob))
        return out

    def _receipt_transforms(self) -> list:
        """Map 'augmentation.receipts' config entries to transforms."""
        out = []
        for aug_config in self.config.get('augmentation.receipts', []):
            aug_type = aug_config.get('type')
            prob = aug_config.get('prob', 0.5)
            if aug_type == 'stain':
                # Approximate stains with random dark blobs (shadows)
                # NOTE(review): num_shadows_lower/upper were replaced by
                # num_shadows_limit in albumentations >= 1.4 — confirm the
                # pinned version still accepts these keywords.
                out.append(
                    A.RandomShadow(
                        shadow_roi=(0, 0, 1, 1),
                        num_shadows_lower=1,
                        num_shadows_upper=3,
                        shadow_dimension=5,
                        p=prob
                    )
                )
            elif aug_type == 'fold':
                # Approximate paper folds with mild grid distortion
                out.append(
                    A.GridDistortion(num_steps=5, distort_limit=0.1, p=prob)
                )
        return out

    def __call__(self, image: np.ndarray, mask: Optional[np.ndarray] = None) -> Dict[str, Any]:
        """
        Apply augmentation

        Args:
            image: Input image (H, W, 3), float32, [0, 1]
            mask: Optional mask (H, W), uint8, {0, 1}

        Returns:
            Dictionary with 'image' (C, H, W) float tensor in [0, 1] and,
            when a mask was given, 'mask' (1, H, W) float tensor in [0, 1]
        """
        # Albumentations photometric transforms expect uint8 input
        image_uint8 = (image * 255).astype(np.uint8)
        if mask is not None:
            mask_uint8 = (mask * 255).astype(np.uint8)
            augmented = self.transform(image=image_uint8, mask=mask_uint8)
            # Convert back to float32 in [0, 1]
            augmented['image'] = augmented['image'].float() / 255.0
            # Restore {0, 1} range and add a channel dim: (H, W) -> (1, H, W)
            augmented['mask'] = (augmented['mask'].float() / 255.0).unsqueeze(0)
        else:
            augmented = self.transform(image=image_uint8)
            augmented['image'] = augmented['image'].float() / 255.0
        return augmented
def get_augmentation(config, dataset_name: str, is_training: bool = True) -> DatasetAwareAugmentation:
    """Factory for the dataset-aware augmentation pipeline.

    Args:
        config: Configuration object
        dataset_name: Dataset name
        is_training: Whether in training mode

    Returns:
        Augmentation pipeline
    """
    pipeline = DatasetAwareAugmentation(config, dataset_name, is_training)
    return pipeline