| |
| |
| |
| |
| |
| |
| |
| |
| import tensorflow as tf |
| from tensorflow import keras |
| from tensorflow.python.keras import backend |
| from tensorflow.python.keras.engine import base_layer |
| from tensorflow.python.keras.engine import base_preprocessing_layer |
| from tensorflow.python.keras.utils import control_flow_util |
|
|
| from image_classification.tf.src.data_augmentation import data_augmentation |
|
|
@tf.keras.utils.register_keras_serializable()
class DataAugmentationLayer(tf.keras.layers.Layer):
    """Keras layer that applies a `data_augmentation` function in training mode.

    The augmentation function is looked up by name in the project's
    `data_augmentation` module. It is only applied when the layer runs in
    training mode; otherwise the inputs are returned unchanged.

    The layer owns two non-trainable int64 variables that are passed to the
    augmentation function on every training call:

    * ``batch_info`` -- ``[n, n // batches_per_epoch, d1, d2]``, rewritten
      after each training call. ``n`` is the number of training calls made so
      far. ``d1``/``d2`` are the sizes of axes 1 and 2 of the augmented
      output (presumably height and width for NHWC images -- TODO confirm).
    * ``res`` -- a two-element value set through :meth:`change_res` and
      forwarded as ``current_res``.

    Args:
        data_augmentation_fn: Name of the augmentation function. It is
            resolved as an attribute (dotted paths allowed) of the
            `data_augmentation` module.
        config: Object forwarded as the second positional argument of the
            augmentation function.
        pixels_range: Forwarded to the augmentation function as
            ``pixels_range``.
        batches_per_epoch: Number of batches per epoch, used to derive the
            epoch index stored in ``batch_info[1]``. Must be set when the
            layer runs in training mode.
        **kwargs: Standard `tf.keras.layers.Layer` keyword arguments.

    Raises:
        RuntimeError: If ``data_augmentation_fn`` does not name a callable in
            the `data_augmentation` module.
    """

    def __init__(self,
                 data_augmentation_fn,
                 config=None,
                 pixels_range=None,
                 batches_per_epoch=None,
                 **kwargs):
        super(DataAugmentationLayer, self).__init__(**kwargs)
        self.data_augmentation_fn = data_augmentation_fn
        self.config_dict = config
        self.pixels_range = pixels_range
        self.batches_per_epoch = batches_per_epoch

        # Non-trainable state shared with the augmentation function; see the
        # class docstring for the layout of each variable.
        self.batch_info = tf.Variable([0, 0, 0, 0], trainable=False, dtype=tf.int64)
        self.res = tf.Variable([0, 0], trainable=False, dtype=tf.int64)

        self.data_augmentation_func = self._resolve_augmentation_fn(data_augmentation_fn)

    @staticmethod
    def _resolve_augmentation_fn(name):
        """Return the callable named `name` inside the `data_augmentation` module.

        Attribute lookup is used instead of ``eval`` so that a name coming from
        a serialized config cannot execute arbitrary code. Dotted names such
        as ``"sub.fn"`` are resolved one attribute at a time.

        Raises:
            RuntimeError: If `name` is not a non-empty string, or does not
                resolve to a callable.
        """
        error_msg = ("Unable to find data augmentation function `{}` in "
                     "`data_augmentation` package".format(name))
        if not isinstance(name, str) or not name:
            raise RuntimeError(error_msg)
        target = data_augmentation
        try:
            for attr in name.split("."):
                target = getattr(target, attr)
        except AttributeError as err:
            raise RuntimeError(error_msg) from err
        if not callable(target):
            raise RuntimeError(error_msg)
        return target

    def change_res(self, res):
        """Store the first two entries of `res` in the ``res`` variable.

        The stored value is passed to the augmentation function as
        ``current_res`` on subsequent training calls.
        """
        self.res.assign([res[0], res[1]])

    def call(self, inputs, training=True):
        """Augment `inputs` in training mode; return them unchanged otherwise.

        Args:
            inputs: Batch to process; converted with `tf.convert_to_tensor`.
            training: Python bool or tensor selecting the branch. ``None``
                falls back to `backend.learning_phase()`.

        Returns:
            The augmented tensor in training mode, otherwise `inputs`.

        Raises:
            ValueError: If the training branch is built while
                ``batches_per_epoch`` is ``None``.
        """
        if training is None:
            training = backend.learning_phase()
        inputs = tf.convert_to_tensor(inputs)

        def transform_input_data():
            # Fail with a clear message rather than inside tf.cast(None, ...).
            if self.batches_per_epoch is None:
                raise ValueError(
                    "`batches_per_epoch` must be set to run "
                    "DataAugmentationLayer in training mode.")

            outputs = self.data_augmentation_func(
                inputs,
                self.config_dict,
                pixels_range=self.pixels_range,
                batch_info=self.batch_info,
                current_res=self.res)

            batches_per_epoch = tf.cast(self.batches_per_epoch, tf.int64)
            outputs_shape = tf.cast(tf.shape(outputs), tf.int64)

            # Advance the call counter and record the epoch index plus the
            # sizes of output axes 1 and 2 for the next call.
            batches_seen = self.batch_info[0] + 1
            self.batch_info.assign([
                batches_seen,
                batches_seen // batches_per_epoch,
                outputs_shape[1],
                outputs_shape[2],
            ])
            return outputs

        return control_flow_util.smart_cond(training, transform_input_data, lambda: inputs)

    def get_config(self):
        """Return the constructor arguments merged over the base layer config."""
        config = {
            'data_augmentation_fn': self.data_augmentation_fn,
            'config': self.config_dict,
            'pixels_range': self.pixels_range,
            'batches_per_epoch': self.batches_per_epoch,
        }
        base_config = super(DataAugmentationLayer, self).get_config()
        # Layer-specific keys win over base keys, as in the original merge.
        return {**base_config, **config}
|
|
|
|