File size: 5,245 Bytes
e5abc2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
"""
Data augmentation utilities for the Emotion Recognition System.
"""
import numpy as np
from typing import Tuple, Optional
from tensorflow.keras.preprocessing.image import ImageDataGenerator

import sys
from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent.parent))
from src.config import AUGMENTATION_CONFIG


def get_augmentation_generator(
    rotation_range: int = AUGMENTATION_CONFIG["rotation_range"],
    width_shift_range: float = AUGMENTATION_CONFIG["width_shift_range"],
    height_shift_range: float = AUGMENTATION_CONFIG["height_shift_range"],
    horizontal_flip: bool = AUGMENTATION_CONFIG["horizontal_flip"],
    zoom_range: float = AUGMENTATION_CONFIG["zoom_range"],
    brightness_range: Tuple[float, float] = AUGMENTATION_CONFIG["brightness_range"],
    fill_mode: str = AUGMENTATION_CONFIG["fill_mode"],
    rescale: float = 1./255
) -> ImageDataGenerator:
    """
    Build a Keras ImageDataGenerator from the given augmentation settings.

    Each argument is forwarded unchanged to the ImageDataGenerator keyword
    of the same name. All defaults except ``rescale`` are looked up in
    AUGMENTATION_CONFIG once, when this function is defined.

    Args:
        rotation_range: Degree range for random rotations
        width_shift_range: Fraction for horizontal shifts
        height_shift_range: Fraction for vertical shifts
        horizontal_flip: Whether to randomly flip horizontally
        zoom_range: Range for random zoom
        brightness_range: Range for brightness adjustment
        fill_mode: How points outside the input boundaries are filled
        rescale: Rescaling factor (defaults to 1/255)

    Returns:
        ImageDataGenerator constructed with exactly these settings
    """
    # Gather the settings first so the constructor call stays a one-liner.
    settings = dict(
        rescale=rescale,
        rotation_range=rotation_range,
        width_shift_range=width_shift_range,
        height_shift_range=height_shift_range,
        horizontal_flip=horizontal_flip,
        zoom_range=zoom_range,
        brightness_range=brightness_range,
        fill_mode=fill_mode,
    )
    return ImageDataGenerator(**settings)


def augment_image(
    image: np.ndarray,
    num_augmentations: int = 5,
    generator: Optional[ImageDataGenerator] = None
) -> np.ndarray:
    """
    Generate augmented versions of a single image.

    Args:
        image: Input image array of shape (height, width, channels)
        num_augmentations: Number of augmented images to generate. Values
            of zero or below produce an empty batch.
        generator: Optional ImageDataGenerator; when None, one is created
            via get_augmentation_generator(rescale=1.0)

    Returns:
        Array of augmented images of shape
        (num_augmentations, height, width, channels). When nothing is
        requested the result has shape (0, height, width, channels).
    """
    # Nothing requested: return an empty batch that still carries the image
    # shape, so callers can concatenate or index it like any other result
    # (a bare np.array([]) would have shape (0,) and break np.concatenate).
    if num_augmentations <= 0:
        return np.empty((0,) + np.shape(image))

    if generator is None:
        generator = get_augmentation_generator(rescale=1.0)  # No rescale for single images

    # The generator works on batches, so wrap the image in a batch of one.
    image_batch = np.expand_dims(image, axis=0)
    aug_iter = generator.flow(image_batch, batch_size=1)

    # Each next() yields a batch of one freshly transformed image; unwrap it.
    return np.array([next(aug_iter)[0] for _ in range(num_augmentations)])


def create_balanced_augmentation(
    images: np.ndarray,
    labels: np.ndarray,
    target_samples_per_class: int
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Top up under-represented classes with augmented copies.

    Every original sample is kept. For each class holding fewer than
    ``target_samples_per_class`` samples, randomly chosen members of that
    class are augmented until the target is reached. Classes already at or
    above the target are passed through unchanged (they are not reduced).

    Args:
        images: Array of images, indexed along the first axis
        labels: Array of one-hot encoded labels, one row per image
        target_samples_per_class: Minimum number of samples wanted per class

    Returns:
        Tuple of (augmented_images, augmented_labels), grouped by class in
        ascending class-index order
    """
    # One shared generator; rescale=1.0 leaves pixel scaling to the caller.
    generator = get_augmentation_generator(rescale=1.0)

    # Collapse one-hot rows to integer class ids.
    class_ids = np.argmax(labels, axis=1)

    out_images = []
    out_labels = []

    for class_id in np.unique(class_ids):
        members = class_ids == class_id
        member_images = images[members]
        member_labels = labels[members]
        available = len(member_images)

        # Originals always go into the output first.
        out_images.extend(member_images)
        out_labels.extend(member_labels)

        # Class is already large enough: nothing to synthesize.
        if available >= target_samples_per_class:
            continue

        shortfall = target_samples_per_class - available
        for _ in range(shortfall):
            # Draw a random source image from this class...
            pick = np.random.randint(0, available)
            source = member_images[pick]

            # ...and add a single augmented variant carrying the same label.
            variants = augment_image(
                source, num_augmentations=1, generator=generator
            )
            out_images.append(variants[0])
            out_labels.append(member_labels[pick])

    return np.array(out_images), np.array(out_labels)


def get_augmentation_preview(
    image: np.ndarray,
    num_samples: int = 9
) -> np.ndarray:
    """
    Generate a preview of augmentations for visualization.

    The first entry of the result is the unmodified input image; the
    remaining ``num_samples - 1`` entries come from augment_image with its
    default generator.

    Args:
        image: Original image
        num_samples: Total number of images in the preview, counting the
            original. Must be at least 1.

    Returns:
        Array including original + augmented images, with ``num_samples``
        entries along the first axis

    Raises:
        ValueError: If ``num_samples`` is less than 1
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be at least 1, got {num_samples}")

    # Add a batch axis so the original can sit alongside the augmented batch.
    original = np.expand_dims(image, axis=0)

    if num_samples == 1:
        # Only the original fits. Return a copy so the result never aliases
        # the caller's array, just like the concatenate path below.
        return original.copy()

    augmented = augment_image(image, num_augmentations=num_samples - 1)

    # Original first, augmented variants after it.
    return np.concatenate([original, augmented], axis=0)