File size: 3,976 Bytes
747451d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# /*---------------------------------------------------------------------------------------------
#  * Copyright (c) 2022-2023 STMicroelectronics.
#  * All rights reserved.
#  *
#  * This software is licensed under terms that can be found in the LICENSE file in
#  * the root directory of this software component.
#  * If no LICENSE file comes with this software, it is provided AS-IS.
#  *--------------------------------------------------------------------------------------------*/
import tensorflow as tf
from tensorflow import keras
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import base_preprocessing_layer
from tensorflow.python.keras.utils import control_flow_util

from image_classification.tf.src.data_augmentation import data_augmentation

@tf.keras.utils.register_keras_serializable()
class DataAugmentationLayer(tf.keras.layers.Layer):
    """Keras layer that applies a named data augmentation function while training.

    The augmentation function is looked up by name in the ``data_augmentation``
    module. It is called on every batch the layer receives in training mode; in
    inference mode the inputs are returned unchanged.

    Args:
        data_augmentation_fn: Name of the augmentation function, resolved as an
            attribute of the ``data_augmentation`` module. Dotted names
            (``"sub.fn"``) are followed attribute by attribute.
        config: Object forwarded unchanged as the second positional argument of
            the augmentation function.
        pixels_range: Forwarded to the augmentation function as ``pixels_range``.
        batches_per_epoch: Number of batches per epoch. It is used to derive the
            epoch index stored in ``batch_info[1]`` and must be set to augment
            data in training mode.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.

    Raises:
        RuntimeError: If ``data_augmentation_fn`` does not name a callable in the
            ``data_augmentation`` module.
    """

    def __init__(self,
                 data_augmentation_fn,
                 config=None,
                 pixels_range=None,
                 batches_per_epoch=None,
                 **kwargs):
        super().__init__(**kwargs)
        self.data_augmentation_fn = data_augmentation_fn
        self.config_dict = config
        self.pixels_range = pixels_range
        self.batches_per_epoch = batches_per_epoch

        # Batch bookkeeping, passed to the augmentation function on every call:
        #   batch_info[0]   number of batches augmented since the start of training
        #   batch_info[1]   epoch index, batch_info[0] // batches_per_epoch
        #   batch_info[2]   size of dimension 1 of the previously augmented batch
        #   batch_info[3]   size of dimension 2 of the previously augmented batch
        # NOTE(review): earlier comments called [2] "width" and [3] "height", but
        # for NHWC tensors dim 1 is height and dim 2 is width -- confirm which
        # convention the augmentation functions rely on.
        self.batch_info = tf.Variable([0, 0, 0, 0], trainable=False, dtype=tf.int64)
        # Resolution set through change_res(); forwarded as `current_res`.
        self.res = tf.Variable([0, 0], trainable=False, dtype=tf.int64)

        # Resolve the augmentation function through attribute lookups rather than
        # eval(), so a config-supplied name cannot execute arbitrary code.
        # A non-string name fails in .split() with AttributeError and is reported
        # the same way as an unknown name.
        try:
            func = data_augmentation
            for attr_name in data_augmentation_fn.split("."):
                func = getattr(func, attr_name.strip())
        except AttributeError as err:
            raise RuntimeError(
                "Unable to find data augmentation function `{}` in "
                "`data_augmentation` package".format(data_augmentation_fn)) from err
        if not callable(func):
            raise RuntimeError(
                "`data_augmentation.{}` is not callable".format(data_augmentation_fn))
        self.data_augmentation_func = func

    def change_res(self, res):
        """Set the resolution forwarded to the augmentation function.

        Args:
            res: Indexable whose first two items are written into ``self.res``.
        """
        self.res.assign([res[0], res[1]])

    def call(self, inputs, training=True):
        """Augment ``inputs`` in training mode, otherwise return them unchanged.

        Args:
            inputs: Batch to augment; converted with ``tf.convert_to_tensor``.
            training: Python bool or boolean tensor selecting the branch. If
                ``None``, the Keras learning phase is used.

        Returns:
            The augmented batch in training mode, else ``inputs``.

        Raises:
            ValueError: If the training branch is built while
                ``batches_per_epoch`` is ``None``.
        """
        if training is None:
            training = backend.learning_phase()
        inputs = tf.convert_to_tensor(inputs)

        def transform_input_data():
            # Fail with a clear message instead of tf.cast(None) failing later.
            if self.batches_per_epoch is None:
                raise ValueError(
                    "`batches_per_epoch` must be set to augment data during training")

            outputs = self.data_augmentation_func(
                inputs,
                self.config_dict,
                pixels_range=self.pixels_range,
                batch_info=self.batch_info,
                current_res=self.res)

            # Record the batch counter, the derived epoch index and the sizes of
            # output dims 1 and 2, so the next call can see them.
            output_shape = tf.shape(outputs)
            batches_per_epoch = tf.cast(self.batches_per_epoch, tf.int64)
            batch = self.batch_info[0]
            self.batch_info.assign([
                batch + 1,
                (batch + 1) // batches_per_epoch,
                tf.cast(output_shape[1], tf.int64),
                tf.cast(output_shape[2], tf.int64),
            ])
            return outputs

        return control_flow_util.smart_cond(training, transform_input_data, lambda: inputs)

    def get_config(self):
        """Return the layer config, including the constructor arguments.

        Layer-specific keys override identically named keys from the base
        ``Layer`` config.
        """
        config = {
            'data_augmentation_fn': self.data_augmentation_fn,
            'config': self.config_dict,
            'pixels_range': self.pixels_range,
            'batches_per_epoch': self.batches_per_epoch,
        }
        base_config = super().get_config()
        return {**base_config, **config}