|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """Export quantized tflite model from a trained checkpoint."""
|
|
|
| from __future__ import absolute_import
|
| from __future__ import division
|
| from __future__ import print_function
|
|
|
| import functools
|
| from absl import app
|
| from absl import flags
|
| import tensorflow.compat.v1 as tf
|
| import tensorflow_datasets as tfds
|
| from nets import nets_factory
|
| from preprocessing import preprocessing_factory
|
|
|
# --- Model / checkpoint -----------------------------------------------------
flags.DEFINE_string(
    name="model_name",
    default=None,
    help="The name of the architecture to quantize.")
flags.DEFINE_string(
    name="checkpoint_path",
    default=None,
    help="Path to the training checkpoint.")

# --- Calibration data -------------------------------------------------------
flags.DEFINE_string(
    name="dataset_name",
    default="imagenet2012",
    help="Name of the dataset to use for quantization calibration.")
flags.DEFINE_string(
    name="dataset_dir",
    default=None,
    help="Dataset location.")
flags.DEFINE_string(
    name="dataset_split",
    default="train",
    help="The dataset split (train, validation etc.) to use for calibration.")

# --- Output -----------------------------------------------------------------
flags.DEFINE_string(
    name="output_tflite",
    default=None,
    help="Path to output tflite file.")

# --- Conversion behaviour ---------------------------------------------------
flags.DEFINE_boolean(
    name="use_model_specific_preprocessing",
    default=False,
    help=("When true, uses the preprocessing corresponding to the model as "
          "specified in preprocessing factory."))
flags.DEFINE_boolean(
    name="enable_ema",
    default=True,
    help="Load exponential moving average version of variables.")
flags.DEFINE_integer(
    name="num_steps",
    default=1000,
    help="Number of post-training quantization calibration steps to run.")
flags.DEFINE_integer(
    name="image_size",
    default=224,
    help="Size of the input image.")
flags.DEFINE_integer(
    name="num_classes",
    default=1001,
    help="Number of output classes for the model.")

FLAGS = flags.FLAGS

# Normalization used by the default calibration preprocessing:
# (x - _MEAN_RGB) / _STD_RGB maps pixel values in [0, 255] onto [-1, 1].
_MEAN_RGB = 127.5
_STD_RGB = 127.5
|
|
|
|
|
def _preprocess_for_quantization(image_data, image_size, crop_padding=32):
  """Center-crops an image with padding, resizes it and normalizes it.

  The crop is a square whose side is `image_size / (image_size + crop_padding)`
  times the shorter side of the image, centered in the image. The crop is
  resized to `image_size` x `image_size` with bicubic interpolation, clipped
  back to the [0, 255] pixel range and normalized with
  `(x - _MEAN_RGB) / _STD_RGB`.

  Args:
    image_data: A 3D Tensor holding already-decoded RGB image data. The image
      can be of arbitrary height and width. Pixel values are assumed to be
      8-bit, i.e. in [0, 255] (e.g. uint8 images from TFDS).
    image_size: Height/width of the (square) output image.
    crop_padding: The padding size to use when centering the crop.

  Returns:
    A float32 Tensor of shape [image_size, image_size, channels] with values
    in [-1, 1].
  """
  shape = tf.shape(image_data)
  image_height = shape[0]
  image_width = shape[1]

  # Side of the square crop: a fixed fraction of the shorter image side, so a
  # border proportional to crop_padding is discarded around the center.
  padded_center_crop_size = tf.cast(
      (image_size * 1.0 / (image_size + crop_padding)) *
      tf.cast(tf.minimum(image_height, image_width), tf.float32), tf.int32)

  # Center the crop; the +1 rounds the offset up when the leftover margin is
  # odd.
  offset_height = ((image_height - padded_center_crop_size) + 1) // 2
  offset_width = ((image_width - padded_center_crop_size) + 1) // 2

  image = tf.image.crop_to_bounding_box(
      image_data,
      offset_height=offset_height,
      offset_width=offset_width,
      target_height=padded_center_crop_size,
      target_width=padded_center_crop_size)

  image = tf.image.resize([image], [image_size, image_size],
                          method=tf.image.ResizeMethod.BICUBIC)[0]
  image = tf.cast(image, tf.float32)
  # Bicubic interpolation can overshoot the input pixel range. Clip so the
  # normalized result really lies in [-1, 1]. Otherwise calibration would
  # widen the int8 input quantization range with values real 8-bit inputs
  # never produce.
  image = tf.clip_by_value(image, 0.0, 255.0)
  image -= tf.constant(_MEAN_RGB)
  image /= tf.constant(_STD_RGB)
  return image
|
|
|
|
|
def restore_model(sess, checkpoint_path, enable_ema=True):
  """Initializes all variables in `sess`, then restores them from a checkpoint.

  Args:
    sess: A tensorflow session whose graph holds the variables to restore.
    checkpoint_path: Path to the trained checkpoint.
    enable_ema: (optional) Whether to load the exponential moving average (ema)
      version of the tensorflow variables. Defaults to True. When enabled, the
      set treated as "averaged" consists of:
      - the trainable variables;
      - the "moving_vars" collection;
      - every global variable whose name contains "moving_mean" or
        "moving_variance".
      The name mapping for the Saver is produced by
      `ExponentialMovingAverage.variables_to_restore`. When disabled, the Saver
      is built with `var_list=None`, i.e. its default variable list.
  """
  var_list = None
  if enable_ema:
    # Only the shadow-variable naming of ExponentialMovingAverage is needed
    # here (nothing is averaged), so the decay value is irrelevant.
    ema = tf.train.ExponentialMovingAverage(decay=0.0)
    # Batch-norm statistics are not trainable but are still stored under
    # their averaged names, so include them explicitly.
    batch_norm_stats = [
        v for v in tf.global_variables()
        if "moving_mean" in v.name or "moving_variance" in v.name
    ]
    averaged = set(tf.trainable_variables())
    averaged.update(tf.get_collection("moving_vars"))
    averaged.update(batch_norm_stats)
    var_list = ema.variables_to_restore(list(averaged))

  # Initialize everything first so variables absent from the checkpoint still
  # hold a value; the restore then overwrites the ones it finds.
  sess.run(tf.global_variables_initializer())
  saver = tf.train.Saver(var_list, max_to_keep=1)
  saver.restore(sess, checkpoint_path)
|
|
|
|
|
def _representative_dataset_gen():
  """Yields preprocessed calibration images for post-training quantization.

  Builds an input pipeline over split `FLAGS.dataset_split` of the TFDS
  dataset `FLAGS.dataset_name`. The data lives under `FLAGS.dataset_dir` and
  is downloaded/prepared there if needed.

  Each image is preprocessed in one of two ways:
  - with the model's own preprocessing from `preprocessing_factory`, when
    `FLAGS.use_model_specific_preprocessing` is set;
  - otherwise with `_preprocess_for_quantization`.

  Each item is yielded as a one-element list holding a
  [1, image_size, image_size, 3] numpy array, the format the TFLite
  converter's representative dataset consumes.

  At most `FLAGS.num_steps` images are yielded. If the split runs out first,
  the generator stops early (with a warning) instead of raising.

  The pipeline ops are added to the current default graph and evaluated with
  `Tensor.eval()`. This therefore has to run while a default session is
  active, as it is when `converter.convert()` is called inside `main`.
  """
  image_size = FLAGS.image_size
  dataset = tfds.builder(FLAGS.dataset_name, data_dir=FLAGS.dataset_dir)
  dataset.download_and_prepare()
  data = dataset.as_dataset()[FLAGS.dataset_split]
  iterator = tf.data.make_one_shot_iterator(data)
  if FLAGS.use_model_specific_preprocessing:
    preprocess_fn = functools.partial(
        preprocessing_factory.get_preprocessing(name=FLAGS.model_name),
        output_height=image_size,
        output_width=image_size)
  else:
    preprocess_fn = functools.partial(
        _preprocess_for_quantization, image_size=image_size)
  # Build the read + preprocess ops once. Each eval() below pulls the next
  # example through these same ops.
  features = iterator.get_next()
  image = features["image"]
  image = preprocess_fn(image)
  image = tf.reshape(image, [1, image_size, image_size, 3])
  for step in range(FLAGS.num_steps):
    try:
      sample = image.eval()
    except tf.errors.OutOfRangeError:
      # The split has fewer examples than requested: calibrate on what exists
      # instead of aborting the whole conversion.
      tf.logging.warning(
          "Dataset split exhausted after %d of %d calibration steps.", step,
          FLAGS.num_steps)
      return
    yield [sample]
|
|
|
|
|
def main(_):
  """Builds the network, restores its weights and writes an int8 TFLite model.

  Steps:
  1. Build the network named by --model_name for inference on a single
     float32 placeholder of shape [1, image_size, image_size, 3].
  2. Restore its weights from --checkpoint_path.
  3. Convert the softmax output with full-integer post-training quantization
     (int8 input/output, int8 builtin ops only), calibrated on
     `_representative_dataset_gen`.
  4. Write the result to --output_tflite.
  """
  with tf.Graph().as_default(), tf.Session() as sess:
    side = FLAGS.image_size
    model_fn = nets_factory.get_network_fn(
        FLAGS.model_name, num_classes=FLAGS.num_classes, is_training=False)
    images = tf.placeholder(
        tf.float32, shape=(1, side, side, 3), name="images")
    logits, _ = model_fn(images)
    probabilities = tf.nn.softmax(logits)

    restore_model(sess, FLAGS.checkpoint_path, enable_ema=FLAGS.enable_ema)

    converter = tf.lite.TFLiteConverter.from_session(
        sess, [images], [probabilities])
    # Calibration data drives the choice of quantization ranges.
    converter.representative_dataset = tf.lite.RepresentativeDataset(
        _representative_dataset_gen)
    # Full-integer quantization: int8 kernels only, int8 model input/output.
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.inference_input_type = tf.int8
    converter.inference_output_type = tf.int8
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]

    # convert() invokes the calibration generator, which evaluates tensors in
    # the default session, so it has to run inside this `with` block.
    flatbuffer = converter.convert()
    destination = FLAGS.output_tflite
    with tf.gfile.GFile(destination, "wb") as handle:
      handle.write(flatbuffer)
    print("tflite model written to %s" % destination)
|
|
|
|
|
if __name__ == "__main__":
  # None of these flags has a usable default (all default to None). absl
  # validates their presence when app.run parses the command line.
  flags.mark_flags_as_required(
      ["model_name", "checkpoint_path", "dataset_dir", "output_tflite"])
  app.run(main)
|
|
|