|
|
|
|
| import os
|
| import demo_inference
|
| import tensorflow as tf
|
| from tensorflow.python.training import monitored_session
|
| from tensorflow.compat.v1 import flags
|
|
|
| _CHECKPOINT = 'model.ckpt-399731'
|
| _CHECKPOINT_URL = 'http://download.tensorflow.org/models/attention_ocr_2017_08_09.tar.gz'
|
|
|
|
|
| class DemoInferenceTest(tf.test.TestCase):
|
| def setUp(self):
|
| super(DemoInferenceTest, self).setUp()
|
| for suffix in ['.meta', '.index', '.data-00000-of-00001']:
|
| filename = _CHECKPOINT + suffix
|
| self.assertTrue(tf.io.gfile.exists(filename),
|
| msg='Missing checkpoint file %s. '
|
| 'Please download and extract it from %s' %
|
| (filename, _CHECKPOINT_URL))
|
| self._batch_size = 32
|
| flags.FLAGS.dataset_dir = os.path.join(
|
| os.path.dirname(__file__), 'datasets/testdata/fsns')
|
|
|
| def test_moving_variables_properly_loaded_from_a_checkpoint(self):
|
| batch_size = 32
|
| dataset_name = 'fsns'
|
| images_placeholder, endpoints = demo_inference.create_model(batch_size,
|
| dataset_name)
|
| image_path_pattern = 'testdata/fsns_train_%02d.png'
|
| images_data = demo_inference.load_images(image_path_pattern, batch_size,
|
| dataset_name)
|
| tensor_name = 'AttentionOcr_v1/conv_tower_fn/INCE/InceptionV3/Conv2d_2a_3x3/BatchNorm/moving_mean'
|
| moving_mean_tf = tf.compat.v1.get_default_graph().get_tensor_by_name(
|
| tensor_name + ':0')
|
| reader = tf.compat.v1.train.NewCheckpointReader(_CHECKPOINT)
|
| moving_mean_expected = reader.get_tensor(tensor_name)
|
|
|
| session_creator = monitored_session.ChiefSessionCreator(
|
| checkpoint_filename_with_path=_CHECKPOINT)
|
| with monitored_session.MonitoredSession(
|
| session_creator=session_creator) as sess:
|
| moving_mean_np = sess.run(moving_mean_tf,
|
| feed_dict={images_placeholder: images_data})
|
|
|
| self.assertAllEqual(moving_mean_expected, moving_mean_np)
|
|
|
| def test_correct_results_on_test_data(self):
|
| image_path_pattern = 'testdata/fsns_train_%02d.png'
|
| predictions = demo_inference.run(_CHECKPOINT, self._batch_size,
|
| 'fsns',
|
| image_path_pattern)
|
| self.assertEqual([
|
| u'Boulevard de Lunelβββββββββββββββββββ',
|
| 'Rue de Provenceββββββββββββββββββββββ',
|
| 'Rue de Port Mariaββββββββββββββββββββ',
|
| 'Avenue Charles Gounodββββββββββββββββ',
|
| 'Rue de lβAuroreββββββββββββββββββββββ',
|
| 'Rue de Beuzevilleββββββββββββββββββββ',
|
| 'Rue dβOrbeyββββββββββββββββββββββββββ',
|
| 'Rue Victor Schoulcherββββββββββββββββ',
|
| 'Rue de la Gareβββββββββββββββββββββββ',
|
| 'Rue des Tulipesββββββββββββββββββββββ',
|
| 'Rue AndrΓ© Maginotββββββββββββββββββββ',
|
| 'Route de Pringyββββββββββββββββββββββ',
|
| 'Rue des Landellesββββββββββββββββββββ',
|
| 'Rue des Ilettesββββββββββββββββββββββ',
|
| 'Avenue de Maurinβββββββββββββββββββββ',
|
| 'Rue ThΓ©resaββββββββββββββββββββββββββ',
|
| 'Route de la Balmeββββββββββββββββββββ',
|
| 'Rue HΓ©lΓ¨ne Roedererββββββββββββββββββ',
|
| 'Rue Emile Bernardββββββββββββββββββββ',
|
| 'Place de la Mairieβββββββββββββββββββ',
|
| 'Rue des Perrotsββββββββββββββββββββββ',
|
| 'Rue de la LibΓ©rationβββββββββββββββββ',
|
| 'Impasse du Capcirββββββββββββββββββββ',
|
| 'Avenue de la Grand Mareββββββββββββββ',
|
| 'Rue Pierre Brossoletteβββββββββββββββ',
|
| 'Rue de Provenceββββββββββββββββββββββ',
|
| 'Rue du Docteur Mourreββββββββββββββββ',
|
| 'Rue dβOrtheuilβββββββββββββββββββββββ',
|
| 'Rue des Sarmentsβββββββββββββββββββββ',
|
| 'Rue du Centreββββββββββββββββββββββββ',
|
| 'Impasse Pierre Mourguesββββββββββββββ',
|
| 'Rue Marcel Dassaultββββββββββββββββββ'
|
| ], predictions)
|
|
|
|
|
| if __name__ == '__main__':
|
| tf.test.main()
|
|
|