| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """Tests for input_reader_builder.""" |
|
|
| import os |
| import numpy as np |
| import tensorflow as tf |
|
|
| from google.protobuf import text_format |
|
|
| from object_detection.builders import input_reader_builder |
| from object_detection.core import standard_fields as fields |
| from object_detection.protos import input_reader_pb2 |
| from object_detection.utils import dataset_util |
|
|
|
|
| class InputReaderBuilderTest(tf.test.TestCase): |
|
|
| def create_tf_record(self): |
| path = os.path.join(self.get_temp_dir(), 'tfrecord') |
| writer = tf.python_io.TFRecordWriter(path) |
|
|
| image_tensor = np.random.randint(255, size=(4, 5, 3)).astype(np.uint8) |
| flat_mask = (4 * 5) * [1.0] |
| with self.test_session(): |
| encoded_jpeg = tf.image.encode_jpeg(tf.constant(image_tensor)).eval() |
| example = tf.train.Example(features=tf.train.Features(feature={ |
| 'image/encoded': dataset_util.bytes_feature(encoded_jpeg), |
| 'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')), |
| 'image/height': dataset_util.int64_feature(4), |
| 'image/width': dataset_util.int64_feature(5), |
| 'image/object/bbox/xmin': dataset_util.float_list_feature([0.0]), |
| 'image/object/bbox/xmax': dataset_util.float_list_feature([1.0]), |
| 'image/object/bbox/ymin': dataset_util.float_list_feature([0.0]), |
| 'image/object/bbox/ymax': dataset_util.float_list_feature([1.0]), |
| 'image/object/class/label': dataset_util.int64_list_feature([2]), |
| 'image/object/mask': dataset_util.float_list_feature(flat_mask), |
| })) |
| writer.write(example.SerializeToString()) |
| writer.close() |
|
|
| return path |
|
|
| def test_build_tf_record_input_reader(self): |
| tf_record_path = self.create_tf_record() |
|
|
| input_reader_text_proto = """ |
| shuffle: false |
| num_readers: 1 |
| tf_record_input_reader {{ |
| input_path: '{0}' |
| }} |
| """.format(tf_record_path) |
| input_reader_proto = input_reader_pb2.InputReader() |
| text_format.Merge(input_reader_text_proto, input_reader_proto) |
| tensor_dict = input_reader_builder.build(input_reader_proto) |
|
|
| with tf.train.MonitoredSession() as sess: |
| output_dict = sess.run(tensor_dict) |
|
|
| self.assertTrue(fields.InputDataFields.groundtruth_instance_masks |
| not in output_dict) |
| self.assertEquals( |
| (4, 5, 3), output_dict[fields.InputDataFields.image].shape) |
| self.assertEquals( |
| [2], output_dict[fields.InputDataFields.groundtruth_classes]) |
| self.assertEquals( |
| (1, 4), output_dict[fields.InputDataFields.groundtruth_boxes].shape) |
| self.assertAllEqual( |
| [0.0, 0.0, 1.0, 1.0], |
| output_dict[fields.InputDataFields.groundtruth_boxes][0]) |
|
|
| def test_build_tf_record_input_reader_and_load_instance_masks(self): |
| tf_record_path = self.create_tf_record() |
|
|
| input_reader_text_proto = """ |
| shuffle: false |
| num_readers: 1 |
| load_instance_masks: true |
| tf_record_input_reader {{ |
| input_path: '{0}' |
| }} |
| """.format(tf_record_path) |
| input_reader_proto = input_reader_pb2.InputReader() |
| text_format.Merge(input_reader_text_proto, input_reader_proto) |
| tensor_dict = input_reader_builder.build(input_reader_proto) |
|
|
| with tf.train.MonitoredSession() as sess: |
| output_dict = sess.run(tensor_dict) |
|
|
| self.assertEquals( |
| (4, 5, 3), output_dict[fields.InputDataFields.image].shape) |
| self.assertEquals( |
| [2], output_dict[fields.InputDataFields.groundtruth_classes]) |
| self.assertEquals( |
| (1, 4), output_dict[fields.InputDataFields.groundtruth_boxes].shape) |
| self.assertAllEqual( |
| [0.0, 0.0, 1.0, 1.0], |
| output_dict[fields.InputDataFields.groundtruth_boxes][0]) |
| self.assertAllEqual( |
| (1, 4, 5), |
| output_dict[fields.InputDataFields.groundtruth_instance_masks].shape) |
|
|
| def test_raises_error_with_no_input_paths(self): |
| input_reader_text_proto = """ |
| shuffle: false |
| num_readers: 1 |
| load_instance_masks: true |
| """ |
| input_reader_proto = input_reader_pb2.InputReader() |
| text_format.Merge(input_reader_text_proto, input_reader_proto) |
| with self.assertRaises(ValueError): |
| input_reader_builder.build(input_reader_proto) |
|
|
| if __name__ == '__main__': |
| tf.test.main() |
|
|