| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """Train a simple convnet on the MNIST dataset.""" |
| from __future__ import print_function |
|
|
| from absl import app as absl_app |
| from absl import flags |
| import tensorflow as tf |
|
|
| from tensorflow_model_optimization.python.core.keras.compat import keras |
| from tensorflow_model_optimization.python.core.sparsity.keras import prune |
| from tensorflow_model_optimization.python.core.sparsity.keras import pruning_callbacks |
| from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule |
|
|
|
|
FLAGS = flags.FLAGS

flags.DEFINE_string(
    name='output_dir',
    default='/tmp/mnist_train/',
    help='Output directory to hold tensorboard events')

# Training hyper-parameters shared by every model trained in this script.
batch_size = 128
num_classes = 10
epochs = 12

# Short aliases for the pruning schedule class and the Keras layers module.
PolynomialDecay = pruning_schedule.PolynomialDecay
l = keras.layers
|
|
|
|
def build_sequential_model(input_shape):
  """Builds the (unpruned) MNIST convnet as a `keras.Sequential` model.

  Args:
    input_shape: Shape of one input image without the batch axis; it is given
      to the first convolution layer.

  Returns:
    An uncompiled `keras.Sequential` model ending in a `num_classes`-way
    softmax.
  """
  layers = [
      l.Conv2D(
          filters=32,
          kernel_size=5,
          padding='same',
          activation='relu',
          input_shape=input_shape),
      l.MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same'),
      l.BatchNormalization(),
      l.Conv2D(filters=64, kernel_size=5, padding='same', activation='relu'),
      l.MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same'),
      l.Flatten(),
      l.Dense(units=1024, activation='relu'),
      l.Dropout(rate=0.4),
      l.Dense(units=num_classes, activation='softmax'),
  ]
  return keras.Sequential(layers)
|
|
|
|
def build_functional_model(input_shape):
  """Builds the MNIST convnet with the Keras functional API.

  The layer stack matches `build_sequential_model`; only the model-building
  API differs.

  Args:
    input_shape: Shape of one input image without the batch axis.

  Returns:
    An uncompiled `keras.models.Model` with one input and one softmax output.
  """
  inputs = keras.Input(shape=input_shape)

  # Feature extractor and hidden dense layer, applied in order below.
  hidden_stack = (
      l.Conv2D(filters=32, kernel_size=5, padding='same', activation='relu'),
      l.MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same'),
      l.BatchNormalization(),
      l.Conv2D(filters=64, kernel_size=5, padding='same', activation='relu'),
      l.MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same'),
      l.Flatten(),
      l.Dense(units=1024, activation='relu'),
      l.Dropout(rate=0.4),
  )
  tensor = inputs
  for layer in hidden_stack:
    tensor = layer(tensor)

  predictions = l.Dense(units=num_classes, activation='softmax')(tensor)
  return keras.models.Model(inputs=[inputs], outputs=[predictions])
|
|
|
|
def build_layerwise_model(input_shape, **pruning_params):
  """Builds the MNIST convnet with pruning applied layer by layer.

  Only the Conv2D and Dense layers are wrapped with
  `prune.prune_low_magnitude`; pooling, batch-norm, flatten and dropout layers
  are left as plain Keras layers.

  Args:
    input_shape: Shape of one input image without the batch axis; it is
      forwarded to the pruning wrapper around the first convolution.
    **pruning_params: Keyword arguments forwarded unchanged to every
      `prune.prune_low_magnitude` call (e.g. `pruning_schedule`).

  Returns:
    An uncompiled `keras.Sequential` model.
  """

  def pruned(layer):
    # Wrap a single layer for magnitude pruning with the shared settings.
    return prune.prune_low_magnitude(layer, **pruning_params)

  # The first wrapper also carries the model's input shape.
  first_conv = prune.prune_low_magnitude(
      l.Conv2D(filters=32, kernel_size=5, padding='same', activation='relu'),
      input_shape=input_shape,
      **pruning_params)

  return keras.Sequential([
      first_conv,
      l.MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same'),
      l.BatchNormalization(),
      pruned(
          l.Conv2D(
              filters=64, kernel_size=5, padding='same', activation='relu')),
      l.MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same'),
      l.Flatten(),
      pruned(l.Dense(units=1024, activation='relu')),
      l.Dropout(rate=0.4),
      pruned(l.Dense(units=num_classes, activation='softmax')),
  ])
|
|
|
|
def _print_score(score):
  """Prints the test loss and accuracy from a `model.evaluate` result.

  Args:
    score: The list returned by `model.evaluate` for a model compiled with
      `metrics=['accuracy']`: index 0 is the loss, index 1 the accuracy.
  """
  print('Test loss:', score[0])
  print('Test accuracy:', score[1])


def train_and_save(models, x_train, y_train, x_test, y_test, log_dir=None,
                   saved_model_dir='/tmp/saved_model'):
  """Trains, evaluates and SavedModel round-trips each model in `models`.

  For every model: compile it (categorical cross-entropy, Adam, accuracy
  metric), print its summary, and train it for `epochs` epochs with batch size
  `batch_size` and the TFMOT pruning callbacks. It is then evaluated on the
  test set, saved in the TensorFlow SavedModel format, reloaded, and the
  reloaded copy is evaluated again so both scores can be compared.

  Each model gets its own `model_<i>` subdirectory under `log_dir` and under
  `saved_model_dir`. With a single shared directory, the models' TensorBoard
  pruning summaries (which all restart at step 0) would be interleaved in one
  run, and every export would overwrite the previous one.

  Args:
    models: Iterable of uncompiled Keras models; the callers in this script
      pass models wrapped with `prune.prune_low_magnitude`.
    x_train: Training images.
    y_train: One-hot training labels (`main` produces them with
      `keras.utils.to_categorical`).
    x_test: Test images, also used as validation data during `fit`.
    y_test: One-hot test labels.
    log_dir: Root directory for TensorBoard pruning summaries. Defaults to
      `FLAGS.output_dir`.
    saved_model_dir: Root directory for the exported SavedModels.
  """
  import os  # pylint: disable=g-import-not-at-top

  if log_dir is None:
    log_dir = FLAGS.output_dir

  for index, model in enumerate(models):
    # Per-model subdirectory name; keeps runs and exports separate.
    run_name = 'model_%d' % index

    model.compile(
        loss=keras.losses.categorical_crossentropy,
        optimizer='adam',
        metrics=['accuracy'],
    )
    model.summary()

    # UpdatePruningStep drives the pruning schedule during training, and
    # PruningSummaries writes pruning summaries for TensorBoard (see the
    # TFMOT pruning callback docs).
    callbacks = [
        pruning_callbacks.UpdatePruningStep(),
        pruning_callbacks.PruningSummaries(
            log_dir=os.path.join(log_dir, run_name)),
    ]

    model.fit(
        x_train,
        y_train,
        batch_size=batch_size,
        epochs=epochs,
        verbose=1,
        callbacks=callbacks,
        validation_data=(x_test, y_test))
    _print_score(model.evaluate(x_test, y_test, verbose=0))

    # Export and re-import the model so the reloaded copy's score can be
    # compared with the in-memory model's.
    export_dir = os.path.join(saved_model_dir, run_name)
    print('Saving model to: ', export_dir)
    keras.models.save_model(model, export_dir, save_format='tf')
    print('Loading model from: ', export_dir)
    loaded_model = keras.models.load_model(export_dir)
    _print_score(loaded_model.evaluate(x_test, y_test, verbose=0))
|
|
|
|
def _load_mnist(img_rows, img_cols):
  """Loads MNIST and prepares it for the convnets in this script.

  The images get a single channel axis, placed according to
  `keras.backend.image_data_format()`, and are cast to float32 and divided by
  255. The labels are one-hot encoded into `num_classes` columns.

  Args:
    img_rows: Image height in pixels.
    img_cols: Image width in pixels.

  Returns:
    `((x_train, y_train), (x_test, y_test), input_shape)`, where `input_shape`
    is the per-image shape without the batch axis.
  """
  (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

  if keras.backend.image_data_format() == 'channels_first':
    input_shape = (1, img_rows, img_cols)
  else:
    input_shape = (img_rows, img_cols, 1)

  def prepare(images):
    # Insert the channel axis where the backend expects it, then rescale.
    batched = images.reshape((images.shape[0],) + input_shape)
    return batched.astype('float32') / 255

  x_train = prepare(x_train)
  x_test = prepare(x_test)
  print('x_train shape:', x_train.shape)
  print(x_train.shape[0], 'train samples')
  print(x_test.shape[0], 'test samples')

  y_train = keras.utils.to_categorical(y_train, num_classes)
  y_test = keras.utils.to_categorical(y_test, num_classes)
  return (x_train, y_train), (x_test, y_test), input_shape


def main(unused_argv):
  """Trains, evaluates and exports three pruned variants of the MNIST convnet.

  The variants are: a Sequential model pruned layer by layer
  (`build_layerwise_model`), a whole Sequential model passed to
  `prune.prune_low_magnitude`, and a whole functional model passed to
  `prune.prune_low_magnitude`. All three share one polynomial-decay schedule.
  """
  img_rows, img_cols = 28, 28
  (x_train, y_train), (x_test, y_test), input_shape = _load_mnist(
      img_rows, img_cols)

  schedule = PolynomialDecay(
      initial_sparsity=0.1,
      final_sparsity=0.75,
      begin_step=1000,
      end_step=5000,
      frequency=100)
  pruning_params = {'pruning_schedule': schedule}

  # List elements are evaluated left to right, so the models are built in the
  # same order as before: layerwise, then sequential, then functional.
  models = [
      build_layerwise_model(input_shape, **pruning_params),
      prune.prune_low_magnitude(
          build_sequential_model(input_shape), **pruning_params),
      prune.prune_low_magnitude(
          build_functional_model(input_shape), **pruning_params),
  ]
  train_and_save(models, x_train, y_train, x_test, y_test)
|
|
|
|
if __name__ == '__main__':
  # Script entry point: absl_app.run invokes main with the remaining argv.
  absl_app.run(main)
|
|