|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| r"""Tests for detection_inference.py."""
|
|
|
| import os
|
| import unittest
|
| import numpy as np
|
| from PIL import Image
|
| import six
|
| import tensorflow.compat.v1 as tf
|
| from google.protobuf import text_format
|
|
|
| from object_detection.core import standard_fields
|
| from object_detection.inference import detection_inference
|
| from object_detection.utils import dataset_util
|
| from object_detection.utils import tf_version
|
|
|
|
|
def get_mock_tfrecord_path():
  """Returns the location of the mock TFRecord inside the test temp dir."""
  temp_dir = tf.test.get_temp_dir()
  return os.path.join(temp_dir, 'mock.tfrec')
|
|
|
|
|
def create_mock_tfrecord():
  """Writes a single-example TFRecord to the mock TFRecord path.

  The example holds a 1x1 RGB PNG whose only pixel is [123, 0, 0], stored
  under the standard encoded-image key. It also holds a float feature
  'test_field' with values [1, 2, 3, 4].

  Returns:
    The PNG-encoded image bytes that were stored in the example.
  """
  pixels = np.array([[[123, 0, 0]]], dtype=np.uint8)
  png_buffer = six.BytesIO()
  Image.fromarray(pixels, 'RGB').save(png_buffer, format='png')
  png_bytes = png_buffer.getvalue()

  example = tf.train.Example(
      features=tf.train.Features(
          feature={
              'test_field':
                  dataset_util.float_list_feature([1, 2, 3, 4]),
              standard_fields.TfExampleFields.image_encoded:
                  dataset_util.bytes_feature(png_bytes),
          }))

  with tf.python_io.TFRecordWriter(get_mock_tfrecord_path()) as writer:
    writer.write(example.SerializeToString())
  return png_bytes
|
|
|
|
|
def get_mock_graph_path():
  """Returns the location of the serialized mock GraphDef in the temp dir."""
  temp_dir = tf.test.get_temp_dir()
  return os.path.join(temp_dir, 'mock_graph.pb')
|
|
|
|
|
def create_mock_graph():
  """Builds a tiny detection graph and writes its GraphDef to the mock path.

  The graph defines these tensors:
    * 'image_tensor': a uint8 placeholder of shape [1, None, None, 3].
    * 'num_detections': the constant [2.0].
    * 'detection_boxes': three constant boxes.
    * 'detection_scores': the constant scores [[0.1, 0.2, 0.3]].
    * 'detection_classes': [[1, 2, 3]] scaled by the sum of all input pixel
      values. This makes the output depend on the image that is fed in.
  """
  g = tf.Graph()
  with g.as_default():
    in_image_tensor = tf.placeholder(
        tf.uint8, shape=[1, None, None, 3], name='image_tensor')
    tf.constant([2.0], name='num_detections')
    tf.constant(
        [[[0, 0.8, 0.7, 1], [0.1, 0.2, 0.8, 0.9], [0.2, 0.3, 0.4, 0.5]]],
        name='detection_boxes')
    tf.constant([[0.1, 0.2, 0.3]], name='detection_scores')
    tf.identity(
        tf.constant([[1.0, 2.0, 3.0]]) *
        tf.reduce_sum(tf.cast(in_image_tensor, dtype=tf.float32)),
        name='detection_classes')
    graph_def = g.as_graph_def()

  # SerializeToString() returns raw protobuf bytes, so write them through a
  # binary-mode handle rather than a text-mode one.
  with tf.gfile.Open(get_mock_graph_path(), 'wb') as fl:
    fl.write(graph_def.SerializeToString())
|
|
|
|
|
@unittest.skipIf(tf_version.is_tf2(), 'Skipping TF1.X only test.')
class InferDetectionsTests(tf.test.TestCase):
  """End-to-end tests for the TF1 `detection_inference` helpers.

  Each test does the following:
    1. Writes the mock graph and the mock TFRecord.
    2. Builds the input and inference tensors from them.
    3. Runs inference once in a fresh session.
    4. Compares the returned tf.Example against the expected detections.
  """

  # Detection features that both tests expect, plus the pass-through
  # 'test_field'. The mock graph emits three boxes and scores, but only the
  # first two appear here (presumably because its num_detections is 2).
  # The labels are the mock class multipliers [1, 2] times the channel sum
  # of the single input pixel (123 + 0 + 0).
  _EXPECTED_DETECTIONS_PBTXT = r"""
      features {
        feature {
          key: "image/detection/bbox/ymin"
          value { float_list { value: [0.0, 0.1] } } }
        feature {
          key: "image/detection/bbox/xmin"
          value { float_list { value: [0.8, 0.2] } } }
        feature {
          key: "image/detection/bbox/ymax"
          value { float_list { value: [0.7, 0.8] } } }
        feature {
          key: "image/detection/bbox/xmax"
          value { float_list { value: [1.0, 0.9] } } }
        feature {
          key: "image/detection/label"
          value { int64_list { value: [123, 246] } } }
        feature {
          key: "image/detection/score"
          value { float_list { value: [0.1, 0.2] } } }
        feature {
          key: "test_field"
          value { float_list { value: [1.0, 2.0, 3.0, 4.0] } } } }
      """

  def _build_pipeline(self):
    """Writes the mock inputs and builds the input and inference tensors.

    Also checks the static shape of the decoded image tensor.

    Returns:
      A tuple (encoded_image, serialized_example_tensor, detection_tensors).
      detection_tensors is the (boxes, scores, labels) tuple returned by
      `detection_inference.build_inference_graph`.
    """
    create_mock_graph()
    encoded_image = create_mock_tfrecord()

    serialized_example_tensor, image_tensor = detection_inference.build_input(
        [get_mock_tfrecord_path()])
    self.assertAllEqual(image_tensor.get_shape().as_list(), [1, None, None, 3])

    detection_tensors = detection_inference.build_inference_graph(
        image_tensor, get_mock_graph_path())
    return encoded_image, serialized_example_tensor, detection_tensors

  def _run_inference(self, serialized_example_tensor, detection_tensors,
                     discard_image_pixels):
    """Runs one inference step in a new session and returns the tf.Example.

    Queue-runner threads are started under a Coordinator. They are always
    stopped and joined before returning, so no input thread outlives the
    session.
    """
    (detected_boxes_tensor, detected_scores_tensor,
     detected_labels_tensor) = detection_tensors
    with self.test_session(use_gpu=False) as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(tf.local_variables_initializer())
      coord = tf.train.Coordinator()
      threads = tf.train.start_queue_runners(sess=sess, coord=coord)
      try:
        return detection_inference.infer_detections_and_add_to_example(
            serialized_example_tensor, detected_boxes_tensor,
            detected_scores_tensor, detected_labels_tensor,
            discard_image_pixels)
      finally:
        coord.request_stop()
        coord.join(threads)

  def test_simple(self):
    encoded_image, serialized_example_tensor, detection_tensors = (
        self._build_pipeline())

    tf_example = self._run_inference(
        serialized_example_tensor, detection_tensors, False)

    expected_example = tf.train.Example()
    text_format.Merge(self._EXPECTED_DETECTIONS_PBTXT, expected_example)
    # When pixels are kept, the original encoded image must come back
    # unchanged alongside the detections.
    expected_example.features.feature[
        standard_fields.TfExampleFields.image_encoded].CopyFrom(
            dataset_util.bytes_feature(encoded_image))
    self.assertProtoEquals(expected_example, tf_example)

  def test_discard_image(self):
    _, serialized_example_tensor, detection_tensors = self._build_pipeline()

    tf_example = self._run_inference(
        serialized_example_tensor, detection_tensors, True)

    # When pixels are discarded, the encoded image feature must be absent.
    self.assertProtoEquals(self._EXPECTED_DETECTIONS_PBTXT, tf_example)
|
|
|
|
|
# Run the test cases above through TensorFlow's test runner when this file is
# executed directly.
if __name__ == '__main__':
  tf.test.main()
|
|
|