|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """Test a tflite model using random input data."""
|
|
|
| from __future__ import print_function
|
| from absl import flags
|
| import numpy as np
|
| import tensorflow.compat.v1 as tf
|
|
|
# Command-line flag naming the TFLite model file to load; it has no default.
flags.DEFINE_string(
    name='model_path',
    default=None,
    help='Path to model.')

# Module-level handle on the absl flag values, read by main().
FLAGS = flags.FLAGS
|
|
|
|
|
def _random_input(shape, dtype):
  """Return a random array with the given shape and dtype for a model input.

  Floating-point inputs are sampled uniformly from [0, 1), which is the
  distribution this script has always used. Integer inputs (for example the
  uint8/int8 inputs of quantized models) are sampled uniformly over the full
  range of the integer type. Boolean inputs are sampled as fair coin flips.

  Args:
    shape: Tensor shape, e.g. the 'shape' entry of an interpreter input
      detail (a sequence or 1-D integer array).
    dtype: NumPy dtype or scalar type the tensor expects, e.g. the 'dtype'
      entry of an interpreter input detail.

  Returns:
    An `np.ndarray` with exactly `shape` and `dtype`.
  """
  dtype = np.dtype(dtype)
  if dtype == np.bool_:
    return np.asarray(np.random.random_sample(shape) < 0.5)
  if np.issubdtype(dtype, np.integer):
    info = np.iinfo(dtype)
    # randint's upper bound is exclusive, so +1 makes info.max reachable.
    return np.asarray(
        np.random.randint(info.min, info.max + 1, size=shape, dtype=dtype),
        dtype=dtype)
  return np.asarray(np.random.random_sample(shape), dtype=dtype)


def main(_):
  """Run the model at --model_path once on random inputs and print results.

  The function prints the interpreter's input and output details. It fills
  every input tensor with random data of that tensor's own shape and dtype,
  invokes the interpreter, and prints each output tensor.

  Raises:
    ValueError: If --model_path was not provided.
  """
  # Calling flags.mark_flag_as_required() here would be too late. absl checks
  # validators when flags are parsed, and tf.app.run() has already parsed
  # them before calling main(). So check explicitly instead.
  if FLAGS.model_path is None:
    raise ValueError('--model_path is required.')

  interpreter = tf.lite.Interpreter(model_path=FLAGS.model_path)
  interpreter.allocate_tensors()

  input_details = interpreter.get_input_details()
  print('input_details:', input_details)
  output_details = interpreter.get_output_details()
  print('output_details:', output_details)

  # Feed every input, not just the first, using the dtype the model
  # declares. Quantized models take uint8/int8, and set_tensor rejects a
  # mismatched dtype.
  for detail in input_details:
    interpreter.set_tensor(detail['index'],
                           _random_input(detail['shape'], detail['dtype']))

  interpreter.invoke()

  for detail in output_details:
    print(interpreter.get_tensor(detail['index']))
|
|
|
|
|
if __name__ == '__main__':
  # Register the requirement before tf.app.run() parses argv, so absl
  # rejects a missing --model_path while parsing flags. Registering it only
  # under __main__ (not at import time) keeps the module importable without
  # the flag.
  flags.mark_flag_as_required('model_path')
  tf.app.run()
|
|
|