File size: 5,518 Bytes
ff0e79e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
"""

Dataset-aware augmentation for training

"""

import cv2
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
from typing import Dict, Any, Optional


class DatasetAwareAugmentation:
    """Dataset-aware augmentation pipeline built on albumentations.

    Applies a configurable set of common augmentations plus optional
    dataset-specific ones (currently only for 'receipts'), then converts
    the result to a torch tensor. Augmentations are only applied when
    ``is_training`` is True and the config enables them; tensor conversion
    always happens.
    """

    def __init__(self, config, dataset_name: str, is_training: bool = True):
        """Initialize the augmentation pipeline.

        Args:
            config: Configuration object exposing ``get(key, default)``
                with dotted keys (e.g. 'augmentation.enabled').
            dataset_name: Dataset name; selects dataset-specific
                augmentations (e.g. 'receipts').
            is_training: Whether in training mode. When False, only
                tensor conversion is performed.
        """
        self.config = config
        self.dataset_name = dataset_name
        self.is_training = is_training

        # Build the pipeline once; it is reused for every __call__.
        self.transform = self._build_transform()

    def _build_transform(self) -> A.Compose:
        """Build the albumentations transform pipeline.

        Returns:
            A.Compose applying the configured augmentations (training
            only) followed by ToTensorV2. The optional mask is routed
            through the same spatial transforms via ``additional_targets``.
        """
        transforms = []

        if self.is_training and self.config.get('augmentation.enabled', True):
            transforms.extend(self._common_transforms())

            # Dataset-specific augmentations.
            if self.dataset_name == 'receipts':
                transforms.extend(self._receipt_transforms())

        # Always convert to tensor, even in eval mode.
        transforms.append(ToTensorV2())

        return A.Compose(
            transforms,
            additional_targets={'mask': 'mask'}
        )

    def _common_transforms(self) -> list:
        """Translate 'augmentation.common' config entries into transforms.

        Unknown 'type' values are silently skipped (matching the original
        behavior of the elif chain).
        """
        result = []
        for aug_config in self.config.get('augmentation.common', []):
            aug_type = aug_config.get('type')
            prob = aug_config.get('prob', 0.5)

            if aug_type == 'noise':
                result.append(A.GaussNoise(var_limit=(10.0, 50.0), p=prob))

            elif aug_type == 'motion_blur':
                result.append(A.MotionBlur(blur_limit=7, p=prob))

            elif aug_type == 'jpeg_compression':
                quality_range = aug_config.get('quality', [60, 95])
                result.append(
                    A.ImageCompression(quality_lower=quality_range[0],
                                       quality_upper=quality_range[1],
                                       p=prob)
                )

            elif aug_type == 'lighting':
                # Pick exactly one photometric jitter per application.
                result.append(
                    A.OneOf([
                        A.RandomBrightnessContrast(p=1.0),
                        A.RandomGamma(p=1.0),
                        A.HueSaturationValue(p=1.0),
                    ], p=prob)
                )

            elif aug_type == 'perspective':
                result.append(A.Perspective(scale=(0.02, 0.05), p=prob))

        return result

    def _receipt_transforms(self) -> list:
        """Translate 'augmentation.receipts' config entries into transforms."""
        result = []
        for aug_config in self.config.get('augmentation.receipts', []):
            aug_type = aug_config.get('type')
            prob = aug_config.get('prob', 0.5)

            if aug_type == 'stain':
                # Approximate stains with random dark blobs (shadows).
                result.append(
                    A.RandomShadow(
                        shadow_roi=(0, 0, 1, 1),
                        num_shadows_lower=1,
                        num_shadows_upper=3,
                        shadow_dimension=5,
                        p=prob
                    )
                )

            elif aug_type == 'fold':
                # Approximate paper folds with mild grid distortion.
                result.append(
                    A.GridDistortion(num_steps=5, distort_limit=0.1, p=prob)
                )

        return result

    def __call__(self, image: np.ndarray, mask: Optional[np.ndarray] = None) -> Dict[str, Any]:
        """Apply the augmentation pipeline.

        Args:
            image: Input image (H, W, 3), float32, expected in [0, 1].
            mask: Optional mask (H, W), uint8, values in {0, 1}.

        Returns:
            Dictionary with 'image' (float tensor in [0, 1]) and, when a
            mask was given, 'mask' as a (1, H, W) float tensor in [0, 1].
        """
        # Clip before the uint8 cast: values marginally outside [0, 1]
        # (e.g. from prior float arithmetic) would otherwise wrap around
        # during integer conversion.
        image_uint8 = (np.clip(image, 0.0, 1.0) * 255).astype(np.uint8)

        if mask is not None:
            # Scale {0, 1} -> {0, 255} so albumentations treats it as a
            # normal uint8 mask.
            mask_uint8 = (mask * 255).astype(np.uint8)
            augmented = self.transform(image=image_uint8, mask=mask_uint8)

            # Convert back to float32 in [0, 1].
            augmented['image'] = augmented['image'].float() / 255.0
            # Restore {0, 1} values and add a channel dim -> (1, H, W).
            augmented['mask'] = (augmented['mask'].float() / 255.0).unsqueeze(0)
        else:
            augmented = self.transform(image=image_uint8)
            augmented['image'] = augmented['image'].float() / 255.0

        return augmented


def get_augmentation(config, dataset_name: str, is_training: bool = True) -> DatasetAwareAugmentation:
    """Factory for a dataset-aware augmentation pipeline.

    Args:
        config: Configuration object.
        dataset_name: Dataset name.
        is_training: Whether in training mode.

    Returns:
        A configured DatasetAwareAugmentation instance.
    """
    return DatasetAwareAugmentation(
        config,
        dataset_name=dataset_name,
        is_training=is_training,
    )