InklyAI / src /data /augmentation.py
pravinai's picture
Upload folder using huggingface_hub
8eab354 verified
"""
Data augmentation utilities for signature verification training.
"""
import torch
import numpy as np
from typing import Tuple, List, Union
import albumentations as A
from albumentations.pytorch import ToTensorV2
class SignatureAugmentationPipeline:
    """
    Comprehensive augmentation pipeline for signature verification.

    Two albumentations pipelines are built at construction time:

    * ``train_transform``: resize -> random augmentations selected by the
      strength level -> normalize -> ``ToTensorV2``.
    * ``val_transform``: resize -> normalize -> ``ToTensorV2`` (no random ops).
    """

    # Strength levels that have a dedicated augmentation recipe.
    _KNOWN_STRENGTHS = ('light', 'medium', 'heavy')

    # Per-channel normalization constants (three channels); these values match
    # the commonly used ImageNet statistics.
    _NORM_MEAN = (0.485, 0.456, 0.406)
    _NORM_STD = (0.229, 0.224, 0.225)

    def __init__(self,
                 target_size: Tuple[int, int] = (224, 224),
                 augmentation_strength: str = 'medium'):
        """
        Initialize augmentation pipeline.

        Args:
            target_size: Target size for signature images; element 0 and
                element 1 are passed to ``A.Resize`` in that order.
            augmentation_strength: 'light', 'medium', or 'heavy'. Any other
                value falls back to the heavy recipe and emits a warning.
        """
        self.target_size = target_size
        self.strength = augmentation_strength
        # Define augmentation strategies based on strength
        self._setup_augmentations()

    @staticmethod
    def _resolve_strength(strength: str) -> str:
        """
        Map a requested strength onto one of the known recipe names.

        Args:
            strength: Requested strength level.

        Returns:
            ``strength`` itself when it is 'light', 'medium' or 'heavy';
            otherwise 'heavy' (the historical fallback), after warning.
        """
        # Imported locally so the module's import block does not need to change.
        import warnings

        if strength in SignatureAugmentationPipeline._KNOWN_STRENGTHS:
            return strength

        # Historically every unrecognised value silently selected the heavy
        # recipe. The fallback is kept for backward compatibility, but it is
        # surfaced so that typos do not go unnoticed.
        warnings.warn(
            f"Unknown augmentation_strength {strength!r}; expected one of "
            f"{SignatureAugmentationPipeline._KNOWN_STRENGTHS}. "
            "Falling back to 'heavy'.",
            UserWarning,
            # Skip _resolve_strength, _setup_augmentations and __init__ so the
            # warning is attributed to the code constructing the pipeline.
            stacklevel=4,
        )
        return 'heavy'

    def _resize_op(self):
        """Return a fresh ``A.Resize`` to ``target_size[0]`` x ``target_size[1]``."""
        return A.Resize(self.target_size[0], self.target_size[1])

    def _finalize_ops(self) -> list:
        """Return fresh normalize + to-tensor steps that end every pipeline."""
        return [
            A.Normalize(mean=list(self._NORM_MEAN), std=list(self._NORM_STD)),
            ToTensorV2(),
        ]

    def _random_ops(self, level: str) -> list:
        """
        Return the random augmentations for a resolved strength level.

        Args:
            level: 'light', 'medium', or anything else (treated as heavy).

        Returns:
            List of albumentations transforms, in application order.
        """
        # NOTE(review): HorizontalFlip mirrors the signature; confirm that
        # mirrored signatures are desirable training samples for verification.
        if level == 'light':
            return [
                A.HorizontalFlip(p=0.2),
                A.Rotate(limit=5, p=0.3),
                A.RandomBrightnessContrast(
                    brightness_limit=0.1,
                    contrast_limit=0.1,
                    p=0.3
                ),
            ]

        if level == 'medium':
            return [
                A.HorizontalFlip(p=0.3),
                A.Rotate(limit=10, p=0.4),
                A.RandomBrightnessContrast(
                    brightness_limit=0.15,
                    contrast_limit=0.15,
                    p=0.4
                ),
                A.GaussNoise(var_limit=(5.0, 25.0), p=0.2),
                A.ElasticTransform(
                    alpha=0.5,
                    sigma=25,
                    alpha_affine=25,
                    p=0.2
                ),
            ]

        # heavy
        return [
            A.HorizontalFlip(p=0.4),
            A.Rotate(limit=15, p=0.5),
            A.RandomBrightnessContrast(
                brightness_limit=0.2,
                contrast_limit=0.2,
                p=0.5
            ),
            A.GaussNoise(var_limit=(10.0, 50.0), p=0.3),
            A.ElasticTransform(
                alpha=1,
                sigma=50,
                alpha_affine=50,
                p=0.3
            ),
            A.Perspective(scale=(0.05, 0.1), p=0.2),
            A.ShiftScaleRotate(
                shift_limit=0.05,
                scale_limit=0.1,
                rotate_limit=10,
                p=0.3
            ),
        ]

    def _setup_augmentations(self):
        """Setup augmentation transforms based on strength."""
        level = self._resolve_strength(self.strength)

        # Training: resize first, random ops in the middle, normalize/tensor last.
        self.train_transform = A.Compose(
            [self._resize_op(), *self._random_ops(level), *self._finalize_ops()]
        )

        # Validation transform (minimal): same pipeline without random ops.
        # Fresh transform instances are created so the two Compose objects
        # never share state.
        self.val_transform = A.Compose(
            [self._resize_op(), *self._finalize_ops()]
        )

    def augment_image(self, image: np.ndarray, is_training: bool = True) -> torch.Tensor:
        """
        Apply augmentation to a single image.

        Args:
            image: Input signature image (assumes an H x W x 3 array, given the
                three-channel normalization constants; TODO confirm).
            is_training: Whether to apply training augmentations; when False
                the validation transform is used instead.

        Returns:
            Augmented image as torch tensor (the ``'image'`` entry of the
            transform's result).
        """
        transform = self.train_transform if is_training else self.val_transform
        return transform(image=image)['image']

    def augment_batch(self, images: List[np.ndarray], is_training: bool = True) -> torch.Tensor:
        """
        Apply augmentation to a batch of images.

        Args:
            images: List of images to augment. ``torch.stack`` needs at least
                one tensor, so an empty list is not supported.
            is_training: Whether to apply training augmentations

        Returns:
            Batch of augmented images as torch tensor, stacked along a new
            leading dimension.
        """
        return torch.stack(
            [self.augment_image(image, is_training) for image in images]
        )
class PairAugmentation:
    """
    Specialized augmentation for signature pairs in Siamese networks.

    Holds two albumentations pipelines: ``individual_transform`` (random
    augmentations, applied to each image separately during training) and
    ``shared_transform`` (resize, normalize, to-tensor; always applied).
    """

    def __init__(self, target_size: Tuple[int, int] = (224, 224)):
        """
        Initialize pair augmentation.

        Args:
            target_size: Target size for signature images; element 0 and
                element 1 are passed to ``A.Resize`` in that order.
        """
        self.target_size = target_size

        # Deterministic steps that every image of a pair goes through.
        common_steps = [
            A.Resize(self.target_size[0], self.target_size[1]),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2()
        ]
        self.shared_transform = A.Compose(common_steps)

        # Random steps; each image of a pair is passed through them in its
        # own call.
        per_image_steps = [
            A.HorizontalFlip(p=0.3),
            A.Rotate(limit=10, p=0.4),
            A.RandomBrightnessContrast(
                brightness_limit=0.15,
                contrast_limit=0.15,
                p=0.4
            ),
            A.GaussNoise(var_limit=(5.0, 25.0), p=0.2),
        ]
        self.individual_transform = A.Compose(per_image_steps)

    def augment_pair(self,
                     signature1: np.ndarray,
                     signature2: np.ndarray,
                     is_training: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Augment a pair of signatures.

        Args:
            signature1: First signature image
            signature2: Second signature image
            is_training: When True, each image first goes through the
                individual (random) transform; the shared transform is
                applied in both modes.

        Returns:
            Tuple of augmented signature tensors, in input order.
        """
        pair = (signature1, signature2)

        if is_training:
            # Both individual passes run before any shared pass.
            pair = tuple(
                self.individual_transform(image=img)['image'] for img in pair
            )

        first, second = [
            self.shared_transform(image=img)['image'] for img in pair
        ]
        return first, second
class OnlineAugmentation:
    """
    Online augmentation during training for dynamic augmentation.

    Thin callable wrapper around a ``SignatureAugmentationPipeline`` whose
    strength can be swapped at runtime via ``set_strength``.
    """

    def __init__(self, target_size: Tuple[int, int] = (224, 224)):
        """
        Initialize online augmentation.

        Args:
            target_size: Target size for signature images
        """
        self.target_size = target_size
        # Start with the 'medium' pipeline; set_strength builds and stores it.
        self.set_strength('medium')

    def __call__(self, image: np.ndarray, is_training: bool = True) -> torch.Tensor:
        """
        Apply online augmentation.

        Args:
            image: Input signature image
            is_training: Whether to apply training augmentations

        Returns:
            Augmented image as torch tensor (whatever the current pipeline's
            ``augment_image`` returns).
        """
        pipeline = self.augmentation_pipeline
        return pipeline.augment_image(image, is_training)

    def set_strength(self, strength: str):
        """
        Dynamically change augmentation strength.

        Replaces the current pipeline with a newly constructed one that uses
        the stored ``target_size``.

        Args:
            strength: 'light', 'medium', or 'heavy'
        """
        rebuilt = SignatureAugmentationPipeline(
            target_size=self.target_size,
            augmentation_strength=strength
        )
        self.augmentation_pipeline = rebuilt