|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """Tools to convert a quantized deeplab model to tflite."""
|
|
|
| from absl import app
|
| from absl import flags
|
| import numpy as np
|
| from PIL import Image
|
| import tensorflow as tf
|
|
|
|
|
# Command-line flags. Flags whose default is None have no usable fallback
# value; main() reads the graphdef path and input tensor name unconditionally.

# Location of the quantized GraphDef that will be converted.
flags.DEFINE_string(
    name='quantized_graph_def_path',
    default=None,
    help='Path to quantized graphdef.')

# Destination of the converted model; main() skips writing when unset.
flags.DEFINE_string(
    name='output_tflite_path',
    default=None,
    help='Output TFlite model path.')

# Name of the graph tensor that becomes the TFLite model input.
flags.DEFINE_string(
    name='input_tensor_name',
    default=None,
    help=('Input tensor to TFlite model. This usually should be the input '
          'tensor to '
          'model backbone.'))

# Name of the graph tensor that becomes the TFLite model output.
flags.DEFINE_string(
    name='output_tensor_name',
    default='ArgMax:0',
    help=('Output tensor name of TFlite model. By default we output the raw '
          'semantic '
          'label predictions.'))

# Optional image used to compare graphdef and TFLite predictions.
flags.DEFINE_string(
    name='test_image_path',
    default=None,
    help=('Path to an image to test the consistency between input graphdef / '
          'converted tflite model.'))

FLAGS = flags.FLAGS
|
|
|
|
|
def convert_to_tflite(quantized_graphdef,
                      backbone_input_tensor,
                      output_tensor):
    """Converts a quantized DeepLab GraphDef into a TFLite model.

    The GraphDef is imported into a fresh graph (with an empty name prefix),
    the requested input/output tensors are looked up by name, and a TF1-style
    TFLite converter configured for uint8 inference performs the conversion.

    Args:
        quantized_graphdef: GraphDef of the quantized model to convert.
        backbone_input_tensor: Name of the graph tensor used as the single
            input of the TFLite model (tensor-name form, e.g. 'ArgMax:0' is
            the form used for the default output flag).
        output_tensor: Name of the graph tensor used as the single output of
            the TFLite model.

    Returns:
        Whatever `TFLiteConverter.convert()` returns, i.e. the converted
        TFLite model.
    """
    with tf.Graph().as_default() as graph:
        tf.graph_util.import_graph_def(quantized_graphdef, name='')
        # Created while `graph` is the default graph so the session is bound
        # to the freshly imported graph.
        sess = tf.compat.v1.Session()

    # The session owns native resources; close it even when tensor lookup or
    # conversion raises so that it is never leaked.
    try:
        tflite_input = graph.get_tensor_by_name(backbone_input_tensor)
        tflite_output = graph.get_tensor_by_name(output_tensor)
        converter = tf.compat.v1.lite.TFLiteConverter.from_session(
            sess, [tflite_input], [tflite_output])
        converter.inference_type = tf.compat.v1.lite.constants.QUANTIZED_UINT8
        input_arrays = converter.get_input_arrays()
        # (mean, std_dev) for the single model input. NOTE(review): with the
        # converter's (quantized - mean) / std_dev convention this presumably
        # maps uint8 pixels to [-1, 1] -- confirm against the model's
        # preprocessing.
        converter.quantized_input_stats = {input_arrays[0]: (127.5, 127.5)}
        return converter.convert()
    finally:
        sess.close()
|
|
|
|
|
def check_tflite_consistency(graph_def, tflite_model, image_path,
                             output_tensor_name=None,
                             graph_input_tensor_name='ImageTensor:0'):
    """Runs the TFLite model and the graph on one image and compares outputs.

    The image is resized to the TFLite model's input size, fed to both the
    TFLite interpreter and the original graph, and the percentage of
    element-wise equal output values is printed.

    Args:
        graph_def: GraphDef of the original model.
        tflite_model: Converted TFLite model content, as accepted by
            `tf.lite.Interpreter(model_content=...)`.
        image_path: Path of the test image; opened through `tf.io.gfile`.
        output_tensor_name: Name of the graph tensor to fetch. When None
            (the default) the value of the `output_tensor_name` flag is used.
        graph_input_tensor_name: Name of the graph tensor the image is fed
            into. Defaults to 'ImageTensor:0'.
    """
    if output_tensor_name is None:
        output_tensor_name = FLAGS.output_tensor_name

    # Load the TFLite model and read its expected input size.
    interpreter = tf.lite.Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    # Assumes the input shape is (batch, height, width, ...) -- TODO confirm.
    height, width = input_details[0]['shape'][1:3]

    # Prepare the input: RGB, resized to the model input, with a leading
    # batch dimension of size 1. PIL's resize() takes (width, height).
    with tf.io.gfile.GFile(image_path, 'rb') as f:
        image = Image.open(f)
        image = np.asarray(image.convert('RGB').resize((width, height)))
        image = np.expand_dims(image, 0)

    # Prediction of the TFLite model.
    interpreter.set_tensor(input_details[0]['index'], image)
    interpreter.invoke()
    output_tflite = interpreter.get_tensor(output_details[0]['index'])

    # Prediction of the original graph on the very same array.
    # NOTE(review): the graph is fed an image already resized to the TFLite
    # input size; presumably any resize/pad inside the graph is then an
    # identity -- confirm.
    with tf.Graph().as_default():
        tf.graph_util.import_graph_def(graph_def, name='')
        with tf.compat.v1.Session() as sess:
            output_graph = sess.run(
                output_tensor_name,
                feed_dict={graph_input_tensor_name: image})

    print('%.2f%% pixels have matched semantic labels.' % (
        100 * np.mean(output_graph == output_tflite)))
|
|
|
|
|
def main(unused_argv):
    """Converts the graphdef named by the flags and optionally checks it.

    Reads the quantized GraphDef from `--quantized_graph_def_path`, converts
    it to TFLite, writes the result to `--output_tflite_path` when that flag
    is set, and runs the consistency check when `--test_image_path` is set.

    Args:
        unused_argv: Positional command-line arguments passed by `app.run`;
            ignored.
    """
    with tf.io.gfile.GFile(FLAGS.quantized_graph_def_path, 'rb') as graph_file:
        serialized_graph = graph_file.read()
    graph_def = tf.compat.v1.GraphDef.FromString(serialized_graph)

    tflite_model = convert_to_tflite(
        graph_def, FLAGS.input_tensor_name, FLAGS.output_tensor_name)

    # Persisting the model is optional.
    output_path = FLAGS.output_tflite_path
    if output_path:
        with tf.io.gfile.GFile(output_path, 'wb') as tflite_file:
            tflite_file.write(tflite_model)

    # The consistency check only runs when a test image was supplied.
    test_image = FLAGS.test_image_path
    if test_image:
        check_tflite_consistency(graph_def, tflite_model, test_image)
|
|
|
|
|
if __name__ == '__main__':
    # Both flags default to None yet main() uses them unconditionally, so
    # reject a missing value at flag-parsing time with a clear usage error
    # instead of failing later inside TensorFlow.
    flags.mark_flags_as_required(
        ['quantized_graph_def_path', 'input_tensor_name'])
    app.run(main)
|
|
|