# /*---------------------------------------------------------------------------------------------
#  * Copyright (c) 2022-2023 STMicroelectronics.
#  * All rights reserved.
#  *
#  * This software is licensed under terms that can be found in the LICENSE file in
#  * the root directory of this software component.
#  * If no LICENSE file comes with this software, it is provided AS-IS.
#  *--------------------------------------------------------------------------------------------*/

import tensorflow as tf
from tensorflow import keras
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import base_preprocessing_layer
from tensorflow.python.keras.utils import control_flow_util

from image_classification.tf.src.data_augmentation import data_augmentation


@tf.keras.utils.register_keras_serializable()
class DataAugmentationLayer(tf.keras.layers.Layer):
    """Keras layer that applies a function from the `data_augmentation` package
    to each incoming batch when in training mode.

    Args:
        data_augmentation_fn: Name of the augmentation function, resolved as an
            attribute of the `data_augmentation` package (dotted names such as
            "submodule.func" are followed attribute by attribute).
        config: Object passed unchanged as the second positional argument of the
            augmentation function.
        pixels_range: Passed unchanged to the augmentation function as the
            `pixels_range` keyword argument.
        batches_per_epoch: Used to derive the epoch index stored in `batch_info`.
            The training path passes it through `tf.cast(..., tf.int64)` and
            integer-divides by it, so presumably it must be a non-zero integer
            whenever the layer runs in training mode.
        **kwargs: Forwarded to `tf.keras.layers.Layer`.

    Raises:
        RuntimeError: if `data_augmentation_fn` cannot be resolved to a callable
            in the `data_augmentation` package.
    """

    def __init__(self, data_augmentation_fn, config=None, pixels_range=None,
                 batches_per_epoch=None, **kwargs):
        # base_preprocessing_layer.keras_kpl_gauge.get_cell('DataAugmentationLayer').set(True)
        super(DataAugmentationLayer, self).__init__(**kwargs)
        self.data_augmentation_fn = data_augmentation_fn
        self.config_dict = config
        self.pixels_range = pixels_range
        self.batches_per_epoch = batches_per_epoch

        # Non-trainable variable used to keep track of batch info across calls;
        # it is handed to the augmentation function as `batch_info`:
        #   batch_info[0]: number of batches augmented since the beginning of training
        #   batch_info[1]: epoch index, computed as batch_info[0] // batches_per_epoch
        #   batch_info[2]: dimension 1 of the previous augmented batch's shape
        #   batch_info[3]: dimension 2 of the previous augmented batch's shape
        # NOTE(review): earlier notes called [2]/[3] "width"/"height" of the
        # previous image; for NHWC tensors dim 1 would be height — confirm.
        self.batch_info = tf.Variable([0, 0, 0, 0], trainable=False, dtype=tf.int64)

        # Resolution handed to the augmentation function as `current_res`;
        # updated through change_res().
        self.res = tf.Variable([0, 0], trainable=False, dtype=tf.int64)

        # Get the data augmentation function the layer will call every time it
        # receives a batch of images to augment.
        self.data_augmentation_func = self._resolve_augmentation_function(data_augmentation_fn)

    @staticmethod
    def _resolve_augmentation_function(name):
        """Return the callable called `name` in the `data_augmentation` package.

        The name is resolved by attribute lookup instead of `eval()`: it comes
        from configuration and must not be executed as arbitrary code. Dotted
        names are followed one attribute at a time.

        Raises:
            RuntimeError: if `name` is not a string, does not resolve, or
                resolves to an object that is not callable.
        """
        message = ("Unable to find data augmentation function `{}` "
                   "in `data_augmentation` package".format(name))
        try:
            func = data_augmentation
            for attr in name.split("."):
                func = getattr(func, attr)
        except (AttributeError, TypeError) as err:
            # AttributeError: unknown name (or `name` has no .split, e.g. None);
            # TypeError: non-string attribute name.
            raise RuntimeError(message) from err
        if not callable(func):
            raise RuntimeError(message)
        return func

    def change_res(self, res):
        """Set the resolution passed to the augmentation function as `current_res`.

        Args:
            res: Indexable object; only `res[0]` and `res[1]` are read.
        """
        self.res.assign([res[0], res[1]])

    def call(self, inputs, training=True):
        """Augment `inputs` in training mode; return them unchanged otherwise.

        When `training` is None, the Keras backend learning phase decides.
        After each augmented batch, `batch_info` is updated with the new batch
        count, the derived epoch index and dims 1 and 2 of the output shape.
        """
        if training is None:
            training = backend.learning_phase()
        inputs = tf.convert_to_tensor(inputs)

        def transform_input_data():
            # Call the user-selected data augmentation function.
            outputs = self.data_augmentation_func(
                inputs,
                self.config_dict,
                pixels_range=self.pixels_range,
                batch_info=self.batch_info,
                current_res=self.res)

            # Record the batch info. The shape is read from `outputs`, not
            # `inputs`, so it reflects whatever the augmentation produced.
            batches_per_epoch_int64 = tf.cast(self.batches_per_epoch, tf.int64)
            shape_outputs1_int64 = tf.cast(tf.shape(outputs)[1], tf.int64)
            shape_outputs2_int64 = tf.cast(tf.shape(outputs)[2], tf.int64)
            batch = self.batch_info[0]
            self.batch_info.assign([
                batch + 1,
                (batch + 1) // batches_per_epoch_int64,
                shape_outputs1_int64,
                shape_outputs2_int64,
            ])
            return outputs

        return control_flow_util.smart_cond(training, transform_input_data, lambda: inputs)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = {
            'data_augmentation_fn': self.data_augmentation_fn,
            'config': self.config_dict,
            'pixels_range': self.pixels_range,
            'batches_per_epoch': self.batches_per_epoch,
        }
        base_config = super(DataAugmentationLayer, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))