|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """Tests for model_export."""
|
| import os
|
|
|
| import numpy as np
|
| from absl.testing import flagsaver
|
| import tensorflow as tf
|
| from tensorflow.compat.v1 import flags
|
|
|
| import common_flags
|
| import model_export
|
|
|
| _CHECKPOINT = 'model.ckpt-399731'
|
| _CHECKPOINT_URL = (
|
| 'http://download.tensorflow.org/models/attention_ocr_2017_08_09.tar.gz')
|
|
|
|
|
def _clean_up():
    """Delete the TensorFlow test temp directory and everything inside it.

    Does nothing when the directory does not exist, so the function is safe
    to call repeatedly (for example from consecutive setUp calls).
    """
    temp_dir = tf.compat.v1.test.get_temp_dir()
    # rmtree raises NotFoundError for a missing path. Without this guard, a
    # test that fails before re-creating the directory would make the next
    # test's setUp fail as well, hiding the original error.
    if tf.io.gfile.exists(temp_dir):
        tf.io.gfile.rmtree(temp_dir)
|
|
|
|
|
def _create_tf_example_string(image):
    """Serialize an image into a tf.Example string for feeding the model.

    Args:
      image: Array-like of pixel values; it is flattened to one dimension.

    Returns:
      The serialized tf.Example proto, holding the flattened values as a
      float list under the 'image/encoded' feature key.
    """
    flat_pixels = np.reshape(image, -1)
    example = tf.train.Example()
    encoded_feature = example.features.feature['image/encoded']
    encoded_feature.float_list.value.extend(list(flat_pixels))
    return example.SerializeToString()
|
|
|
|
|
class AttentionOcrExportTest(tf.test.TestCase):
    """Tests for model_export.export_model."""

    def setUp(self):
        """Checks the checkpoint files, sets the flags and resets export_dir.

        Fails the test with a download hint when any checkpoint file is
        missing. Flag values changed here are restored after each test.
        """
        super(AttentionOcrExportTest, self).setUp()
        for suffix in ['.meta', '.index', '.data-00000-of-00001']:
            filename = _CHECKPOINT + suffix
            self.assertTrue(
                tf.io.gfile.exists(filename),
                msg='Missing checkpoint file %s. '
                'Please download and extract it from %s' %
                (filename, _CHECKPOINT_URL))

        # The @flagsaver decorator on the test methods snapshots the flags
        # only after setUp has already run, so it cannot undo the assignments
        # below. Snapshot here and restore on cleanup so the values do not
        # leak into other tests in the same process.
        saved_flags = flagsaver.save_flag_values()
        self.addCleanup(flagsaver.restore_flag_values, saved_flags)
        flags.FLAGS.dataset_name = 'fsns'
        flags.FLAGS.checkpoint = _CHECKPOINT
        flags.FLAGS.dataset_dir = os.path.join(
            os.path.dirname(__file__), 'datasets/testdata/fsns')

        # Start every test from an empty temp directory.
        _clean_up()
        self.export_dir = os.path.join(
            tf.compat.v1.test.get_temp_dir(), 'exported_model')
        # Output name -> tensor name; used directly as the fetches of
        # Session.run, so the results come back keyed by the output name.
        self.minimal_output_signature = {
            'predictions': 'AttentionOcr_v1/predicted_chars:0',
            'scores': 'AttentionOcr_v1/predicted_scores:0',
            'predicted_length': 'AttentionOcr_v1/predicted_length:0',
            'predicted_text': 'AttentionOcr_v1/predicted_text:0',
            'predicted_conf': 'AttentionOcr_v1/predicted_conf:0',
            'normalized_seq_conf': 'AttentionOcr_v1/normalized_seq_conf:0'
        }

    def create_input_feed(self, graph_def, serving):
        """Returns the input feed for the model.

        Creates random images, according to the size specified by
        dataset_name, formats them in the correct way depending on whether
        the model was exported for serving, and returns the correctly keyed
        feed_dict for inference. As a side effect it stores the dataset in
        self.dataset and the generated images in self.images.

        Args:
          graph_def: The object returned by saved_model.loader.load for the
            exported model; only its signature_def map is read.
          serving: Whether the model was exported for Serving.

        Returns:
          The feed_dict suitable for model inference.
        """
        self.dataset = common_flags.create_dataset('test')
        image_shape = self.dataset.image_shape
        self.images = {
            'img1':
                np.random.uniform(
                    low=64, high=192, size=image_shape).astype('uint8'),
            'img2':
                np.random.uniform(
                    low=32, high=224, size=image_shape).astype('uint8'),
        }
        signature_def = graph_def.signature_def[
            tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]

        if serving:
            # Serving exports take one serialized tf.Example per image.
            input_name = signature_def.inputs[
                tf.saved_model.CLASSIFY_INPUTS].name
            return {
                input_name: [
                    _create_tf_example_string(self.images['img1']),
                    _create_tf_example_string(self.images['img2'])
                ]
            }

        # Plain exports take the raw images stacked into a single batch.
        input_name = signature_def.inputs['images'].name
        return {
            input_name: np.stack([self.images['img1'], self.images['img2']])
        }

    def verify_export_load_and_inference(self, export_for_serving=False):
        """Verify exported model can be loaded and inference can run.

        This function will load the exported model in self.export_dir, then
        create some fake images according to the specification of
        FLAGS.dataset_name. It then feeds the input through the model, and
        verifies the shapes of the minimal set of output signatures.

        Note: Model and dataset creation in the underlying library depends
        on the following commandline flags:
          FLAGS.dataset_name

        Args:
          export_for_serving: True if the model was exported for Serving.
            This affects how input is fed into the model.
        """
        tf.compat.v1.reset_default_graph()
        sess = tf.compat.v1.Session()
        # Close the session even when loading or inference raises, so a
        # failing test does not leak it into the tests that follow.
        try:
            meta_graph_def = tf.compat.v1.saved_model.loader.load(
                sess=sess,
                tags=[tf.saved_model.SERVING],
                export_dir=self.export_dir)
            feed_dict = self.create_input_feed(
                meta_graph_def, export_for_serving)
            results = sess.run(
                self.minimal_output_signature, feed_dict=feed_dict)
        finally:
            sess.close()

        # Two images are fed, so per-image outputs have one entry per image.
        out_shape = (2,)
        self.assertEqual(np.shape(results['predicted_conf']), out_shape)
        self.assertEqual(np.shape(results['predicted_text']), out_shape)
        self.assertEqual(np.shape(results['predicted_length']), out_shape)
        self.assertEqual(np.shape(results['normalized_seq_conf']), out_shape)
        # Per-character outputs additionally span the maximum sequence length.
        out_shape = (2, self.dataset.max_sequence_length)
        self.assertEqual(np.shape(results['scores']), out_shape)
        self.assertEqual(np.shape(results['predictions']), out_shape)

    @flagsaver.flagsaver
    def test_fsns_export_for_serving_and_load_inference(self):
        """Exports for serving, then loads the model and runs inference."""
        model_export.export_model(self.export_dir, True)
        self.verify_export_load_and_inference(True)

    @flagsaver.flagsaver
    def test_fsns_export_and_load_inference(self):
        """Exports without serving inputs, then loads and runs inference."""
        model_export.export_model(self.export_dir, False, batch_size=2)
        self.verify_export_load_and_inference(False)
|
|
|
|
|
# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
    tf.test.main()
|
|
|