|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """Functions to export object detection inference graph."""
|
import os
import shutil
import tempfile

import tensorflow.compat.v1 as tf
import tf_slim as slim

from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.tools import freeze_graph

from object_detection.builders import graph_rewriter_builder
from object_detection.builders import model_builder
from object_detection.core import standard_fields as fields
from object_detection.data_decoders import tf_example_decoder
from object_detection.utils import config_util
from object_detection.utils import shape_utils
|
|
|
|
|
| try:
|
| from tensorflow.contrib import tfprof as contrib_tfprof
|
| from tensorflow.contrib.quantize.python import graph_matcher
|
| except ImportError:
|
|
|
| pass
|
|
|
|
|
# Module-level alias re-exporting freeze_graph.freeze_graph_with_def_protos.
# The export code in this module calls freeze_graph.freeze_graph_with_def_protos
# directly; the alias is presumably kept for external callers/tests that
# reference `exporter.freeze_graph_with_def_protos` -- confirm before removing.
freeze_graph_with_def_protos = freeze_graph.freeze_graph_with_def_protos
|
|
|
|
|
def parse_side_inputs(side_input_shapes_string, side_input_names_string,
                      side_input_types_string):
  """Parses side input flags.

  Args:
    side_input_shapes_string: The shape of the side input tensors, provided as a
      comma-separated list of integers. A value of -1 is used for unknown
      dimensions. A `/` denotes a break, starting the shape of the next side
      input tensor. Whitespace around entries is ignored.
    side_input_names_string: The names of the side input tensors, provided as a
      comma-separated list of strings. Whitespace around entries is ignored.
    side_input_types_string: The type of the side input tensors, provided as a
      comma-separated list of types, each of `string`, `integer` (or `int`), or
      `float`. Whitespace around entries is ignored.

  Returns:
    side_input_shapes: A list of shapes.
    side_input_names: A list of strings.
    side_input_types: A list of tensorflow dtypes.

  Raises:
    ValueError: If any of the three flag strings is empty, if a shape entry is
      not an integer, or if a type is not one of the supported type names.
  """
  if not side_input_shapes_string:
    raise ValueError('When using side_inputs, side_input_shapes must be '
                     'specified in the input flags.')
  side_input_shapes = []
  for shape_string in side_input_shapes_string.split('/'):
    dims = [dim.strip() for dim in shape_string.split(',')]
    # -1 marks an unknown dimension, which TF shapes express as None.
    side_input_shapes.append([None if dim == '-1' else int(dim) for dim in dims])

  if not side_input_names_string:
    raise ValueError('When using side_inputs, side_input_names must be '
                     'specified in the input flags.')
  side_input_names = [name.strip() for name in side_input_names_string.split(',')]

  if not side_input_types_string:
    raise ValueError('When using side_inputs, side_input_types must be '
                     'specified in the input flags.')
  type_names = [name.strip() for name in side_input_types_string.split(',')]
  # Validate before building the dtype table so bad flags surface as the same
  # ValueError as every other malformed input (previously a bare KeyError).
  supported_type_names = ('float', 'int', 'integer', 'string')
  unknown_type_names = [name for name in type_names
                        if name not in supported_type_names]
  if unknown_type_names:
    raise ValueError(
        'Unsupported side input type(s): {}. Each type must be one of '
        '`float`, `int`/`integer` or `string`.'.format(
            ', '.join(unknown_type_names)))
  # `integer` is the spelling documented above; `int` is kept for existing
  # callers.
  typelookup = {
      'float': tf.float32,
      'int': tf.int32,
      'integer': tf.int32,
      'string': tf.string,
  }
  side_input_types = [typelookup[name] for name in type_names]
  return side_input_shapes, side_input_names, side_input_types
|
|
|
|
|
def rewrite_nn_resize_op(is_quantized=False):
  """Replaces a custom nearest-neighbor resize op with the Tensorflow version.

  Some graphs use this custom version for TPU-compatibility.

  The custom structure is matched in the default graph as: a Pack of two inputs
  matching the input pattern, a Reshape with a Const shape, another Pack of two
  such Reshapes, and a second Reshape with a Const shape, whose output feeds an
  Add/AddV2/Max/Mul or a StridedSlice op. For each match,
  `tf.image.resize_nearest_neighbor` is applied to the matched input and the
  consumer op is rewired to read that tensor instead.

  Relies on `graph_matcher`, which is only bound when the optional
  `tensorflow.contrib` import at the top of this module succeeds.

  Args:
    is_quantized: True if the default graph is quantized. The matched input op
      must then be a `FakeQuantWithMinMaxVars` op; otherwise any op type matches.

  Raises:
    ValueError: if five consecutive rewrite passes all find matches, which is
      treated as an infinite loop.
  """
  def remove_nn():
    """Remove nearest neighbor upsampling structures and replace with TF op.

    Returns:
      The number of matches rewritten in this pass (0 when none were found).
    """
    # The tensor that the custom structure upsamples.
    input_pattern = graph_matcher.OpTypePattern(
        'FakeQuantWithMinMaxVars' if is_quantized else '*')
    stack_1_pattern = graph_matcher.OpTypePattern(
        'Pack', inputs=[input_pattern, input_pattern], ordered_inputs=False)
    reshape_1_pattern = graph_matcher.OpTypePattern(
        'Reshape', inputs=[stack_1_pattern, 'Const'], ordered_inputs=False)
    stack_2_pattern = graph_matcher.OpTypePattern(
        'Pack',
        inputs=[reshape_1_pattern, reshape_1_pattern],
        ordered_inputs=False)
    reshape_2_pattern = graph_matcher.OpTypePattern(
        'Reshape', inputs=[stack_2_pattern, 'Const'], ordered_inputs=False)
    # Two kinds of consumer of the upsampled tensor are recognized: a binary
    # op (either input position) and a StridedSlice (first of four inputs in
    # any order, since ordered_inputs=False).
    consumer_pattern1 = graph_matcher.OpTypePattern(
        'Add|AddV2|Max|Mul',
        inputs=[reshape_2_pattern, '*'],
        ordered_inputs=False)
    consumer_pattern2 = graph_matcher.OpTypePattern(
        'StridedSlice',
        inputs=[reshape_2_pattern, '*', '*', '*'],
        ordered_inputs=False)

    def replace_matches(consumer_pattern):
      """Search for nearest neighbor pattern and replace with TF op.

      Args:
        consumer_pattern: the OpTypePattern rooted at the op that consumes the
          custom upsampling structure's output.

      Returns:
        The number of matches found (and rewritten) in the default graph.
      """
      match_counter = 0
      matcher = graph_matcher.GraphMatcher(consumer_pattern)
      for match in matcher.match_graph(tf.get_default_graph()):
        match_counter += 1
        projection_op = match.get_op(input_pattern)
        reshape_2_op = match.get_op(reshape_2_pattern)
        consumer_op = match.get_op(consumer_pattern)
        # Resize the matched input straight to dims [1:3] of the custom
        # structure's output (presumably height/width, NHWC -- confirm). The
        # new op is named inside reshape_2_op's name scope.
        nn_resize = tf.image.resize_nearest_neighbor(
            projection_op.outputs[0],
            reshape_2_op.outputs[0].shape.dims[1:3],
            align_corners=False,
            name=os.path.split(reshape_2_op.name)[0] +
            '/resize_nearest_neighbor')

        # Rewire only the first consumer input that reads the custom
        # structure's output. `_update_input` is a private tf.Operation API.
        for index, op_input in enumerate(consumer_op.inputs):
          if op_input == reshape_2_op.outputs[0]:
            consumer_op._update_input(index, nn_resize)
            break

      return match_counter

    match_counter = replace_matches(consumer_pattern1)
    match_counter += replace_matches(consumer_pattern2)

    tf.logging.info('Found and fixed {} matches'.format(match_counter))
    return match_counter

  # Run passes until one finds no matches. Since only one consumer input is
  # rewired per match, more than one pass can presumably be needed (confirm).
  total_removals = 0
  while remove_nn():
    total_removals += 1

    # Safety valve: the fifth consecutive pass with matches aborts.
    if total_removals > 4:
      raise ValueError('Graph removal encountered a infinite loop.')
|
|
|
|
|
def replace_variable_values_with_moving_averages(graph,
                                                 current_checkpoint_file,
                                                 new_checkpoint_file,
                                                 no_ema_collection=None):
  """Writes a checkpoint whose variables hold their moving-average values.

  The shadow (moving-average) values stored in `current_checkpoint_file` are
  restored into the variables of `graph`, and all variables are then saved to
  `new_checkpoint_file`.

  Args:
    graph: a tf.Graph object.
    current_checkpoint_file: a checkpoint containing both original variables and
      their moving averages.
    new_checkpoint_file: file path to write a new checkpoint.
    no_ema_collection: A list of namescope substrings to match the variables
      to eliminate EMA.
  """
  with graph.as_default():
    # Only variables_to_restore() is used from this object, so the decay value
    # passed here has no effect.
    ema = tf.train.ExponentialMovingAverage(0.0)
    restore_map = config_util.remove_unnecessary_ema(
        ema.variables_to_restore(), no_ema_collection)
    with tf.Session() as session:
      tf.train.Saver(restore_map).restore(session, current_checkpoint_file)
      tf.train.Saver().save(session, new_checkpoint_file)
|
|
|
|
|
def _image_tensor_input_placeholder(input_shape=None):
  """Builds a uint8 image placeholder that is also used as the model input.

  Args:
    input_shape: optional static shape; (None, None, None, 3) when omitted.

  Returns:
    A pair (placeholder, input_tensor); both entries are the same tensor.
  """
  shape = (None, None, None, 3) if input_shape is None else input_shape
  image_placeholder = tf.placeholder(
      dtype=tf.uint8, shape=shape, name='image_tensor')
  return image_placeholder, image_placeholder
|
|
|
|
|
def _side_input_tensor_placeholder(side_input_shape, side_input_name,
                                   side_input_type):
  """Builds a placeholder for one side input.

  Returns:
    A pair (placeholder, side_input_tensor); the placeholder is fed to the
    model directly, so both entries are the same tensor.
  """
  placeholder = tf.placeholder(side_input_type, side_input_shape,
                               side_input_name)
  return placeholder, placeholder
|
|
|
|
|
def _tf_example_input_placeholder(input_shape=None):
  """Returns input that accepts a batch of strings with tf examples.

  Args:
    input_shape: the shape to resize the output decoded images to (optional).
      Only entries [1:3] are used, as the target size for `tf.image.resize`.

  Returns:
    a tuple of input placeholder and the output decoded images (uint8).
  """
  batch_tf_example_placeholder = tf.placeholder(
      tf.string, shape=[None], name='tf_example')

  def decode(tf_example_string_tensor):
    """Decodes one serialized tf.Example into its (optionally resized) image."""
    tensor_dict = tf_example_decoder.TfExampleDecoder().decode(
        tf_example_string_tensor)
    image_tensor = tensor_dict[fields.InputDataFields.image]
    if input_shape is not None:
      image_tensor = tf.image.resize(image_tensor, input_shape[1:3])
      # tf.image.resize yields float32 for the default (bilinear) method, but
      # the map below declares uint8 outputs; writing float32 into the uint8
      # TensorArray fails. Round and clamp back to uint8 to keep one dtype.
      if image_tensor.dtype != tf.uint8:
        image_tensor = tf.saturate_cast(tf.round(image_tensor), tf.uint8)
    return image_tensor

  return (batch_tf_example_placeholder,
          shape_utils.static_or_dynamic_map_fn(
              decode,
              elems=batch_tf_example_placeholder,
              dtype=tf.uint8,
              parallel_iterations=32,
              back_prop=False))
|
|
|
|
|
def _encoded_image_string_tensor_input_placeholder(input_shape=None):
  """Returns input that accepts a batch of PNG or JPEG strings.

  Args:
    input_shape: the shape to resize the output decoded images to (optional).
      Only entries [1:3] are used, as the target size for `tf.image.resize`.

  Returns:
    a tuple of input placeholder and the output decoded images (uint8).
  """
  batch_image_str_placeholder = tf.placeholder(
      dtype=tf.string,
      shape=[None],
      name='encoded_image_string_tensor')

  def decode(encoded_image_string_tensor):
    """Decodes one encoded image string into an (optionally resized) image."""
    image_tensor = tf.image.decode_image(encoded_image_string_tensor,
                                         channels=3)
    # Give the decoded image a rank-3, 3-channel static shape, matching the
    # channels=3 requested above.
    image_tensor.set_shape((None, None, 3))
    if input_shape is not None:
      image_tensor = tf.image.resize(image_tensor, input_shape[1:3])
      # tf.image.resize yields float32 for the default (bilinear) method, but
      # tf.map_fn below declares uint8 outputs; writing float32 into the uint8
      # TensorArray fails. Round and clamp back to uint8 to keep one dtype.
      if image_tensor.dtype != tf.uint8:
        image_tensor = tf.saturate_cast(tf.round(image_tensor), tf.uint8)
    return image_tensor

  return (batch_image_str_placeholder,
          tf.map_fn(
              decode,
              elems=batch_image_str_placeholder,
              dtype=tf.uint8,
              parallel_iterations=32,
              back_prop=False))
|
|
|
|
|
# Maps each supported `input_type` string to the function that builds its
# (placeholder, model_input_tensor) pair.
input_placeholder_fn_map = dict(
    image_tensor=_image_tensor_input_placeholder,
    encoded_image_string_tensor=_encoded_image_string_tensor_input_placeholder,
    tf_example=_tf_example_input_placeholder,
)
|
|
|
|
|
def add_output_tensor_nodes(postprocessed_tensors,
                            output_collection_name='inference_op'):
  """Adds output nodes for detection boxes and scores.

  Adds the following nodes for output tensors -
    * num_detections: float32 tensor of shape [batch_size].
    * detection_boxes: float32 tensor of shape [batch_size, num_boxes, 4]
      containing detected boxes.
    * detection_scores: float32 tensor of shape [batch_size, num_boxes]
      containing scores for the detected boxes.
    * detection_multiclass_scores: (Optional) float32 tensor of shape
      [batch_size, num_boxes, num_classes_with_background] for containing class
      score distribution for detected boxes including background if any.
    * detection_features: (Optional) float32 tensor of shape
      [batch, num_boxes, roi_height, roi_width, depth]
      containing classifier features
      for each detected box
    * detection_classes: float32 tensor of shape [batch_size, num_boxes]
      containing class predictions for the detected boxes (the postprocessed
      classes shifted by +1).
    * detection_keypoints: (Optional) float32 tensor of shape
      [batch_size, num_boxes, num_keypoints, 2] containing keypoints for each
      detection box.
    * detection_masks: (Optional) float32 tensor of shape
      [batch_size, num_boxes, mask_height, mask_width] containing masks for each
      detection box.
    * raw_detection_boxes / raw_detection_scores: (Optional) forwarded
      unchanged when present in `postprocessed_tensors`.

  Each node is a tf.identity named after its field, so it can be fetched by
  that name.

  Args:
    postprocessed_tensors: a dictionary containing the following fields
      'detection_boxes': [batch, max_detections, 4]
      'detection_scores': [batch, max_detections]
      'detection_multiclass_scores': [batch, max_detections,
        num_classes_with_background]
      'detection_features': [batch, num_boxes, roi_height, roi_width, depth]
      'detection_classes': [batch, max_detections]
      'detection_masks': [batch, max_detections, mask_height, mask_width]
        (optional).
      'detection_keypoints': [batch, max_detections, num_keypoints, 2]
        (optional).
      'raw_detection_boxes': (optional).
      'raw_detection_scores': (optional).
      'num_detections': [batch]
    output_collection_name: Name of collection to add output tensors to. If
      None, the output tensors are not added to any collection.

  Returns:
    A tensor dict containing the added output tensor nodes.
  """
  detection_fields = fields.DetectionResultFields
  # Exported class ids are the postprocessed ids plus one (presumably so that
  # id 0 stays reserved for background -- confirm against the label map).
  label_id_offset = 1
  boxes = postprocessed_tensors.get(detection_fields.detection_boxes)
  scores = postprocessed_tensors.get(detection_fields.detection_scores)
  multiclass_scores = postprocessed_tensors.get(
      detection_fields.detection_multiclass_scores)
  box_classifier_features = postprocessed_tensors.get(
      detection_fields.detection_features)
  raw_boxes = postprocessed_tensors.get(detection_fields.raw_detection_boxes)
  raw_scores = postprocessed_tensors.get(detection_fields.raw_detection_scores)
  classes = postprocessed_tensors.get(
      detection_fields.detection_classes) + label_id_offset
  keypoints = postprocessed_tensors.get(detection_fields.detection_keypoints)
  masks = postprocessed_tensors.get(detection_fields.detection_masks)
  num_detections = postprocessed_tensors.get(detection_fields.num_detections)
  outputs = {}
  outputs[detection_fields.detection_boxes] = tf.identity(
      boxes, name=detection_fields.detection_boxes)
  outputs[detection_fields.detection_scores] = tf.identity(
      scores, name=detection_fields.detection_scores)
  if multiclass_scores is not None:
    outputs[detection_fields.detection_multiclass_scores] = tf.identity(
        multiclass_scores, name=detection_fields.detection_multiclass_scores)
  if box_classifier_features is not None:
    outputs[detection_fields.detection_features] = tf.identity(
        box_classifier_features,
        name=detection_fields.detection_features)
  outputs[detection_fields.detection_classes] = tf.identity(
      classes, name=detection_fields.detection_classes)
  outputs[detection_fields.num_detections] = tf.identity(
      num_detections, name=detection_fields.num_detections)
  if raw_boxes is not None:
    outputs[detection_fields.raw_detection_boxes] = tf.identity(
        raw_boxes, name=detection_fields.raw_detection_boxes)
  if raw_scores is not None:
    outputs[detection_fields.raw_detection_scores] = tf.identity(
        raw_scores, name=detection_fields.raw_detection_scores)
  if keypoints is not None:
    outputs[detection_fields.detection_keypoints] = tf.identity(
        keypoints, name=detection_fields.detection_keypoints)
  if masks is not None:
    outputs[detection_fields.detection_masks] = tf.identity(
        masks, name=detection_fields.detection_masks)

  # Honor the documented "None means no collection" contract instead of
  # registering the outputs under a collection keyed by None.
  if output_collection_name is not None:
    for output_tensor in outputs.values():
      tf.add_to_collection(output_collection_name, output_tensor)

  return outputs
|
|
|
|
|
def write_saved_model(saved_model_path,
                      frozen_graph_def,
                      inputs,
                      outputs):
  """Writes a frozen detection graph to disk as a SavedModel.

  `frozen_graph_def` is imported into a fresh graph and session. A single
  signature is built from `inputs` and `outputs` (PREDICT method, registered
  under the default serving signature key) and saved with the SERVING tag.

  Args:
    saved_model_path: Path to write SavedModel.
    frozen_graph_def: tf.GraphDef holding frozen graph.
    inputs: Either a dictionary mapping signature input names to tensors, or a
      single tensor, which is then exposed under the input name 'inputs'.
    outputs: A dictionary mapping signature output names to tensors.
  """
  to_info = tf.saved_model.utils.build_tensor_info
  signature_constants = tf.saved_model.signature_constants
  with tf.Graph().as_default(), tf.Session() as sess:
    tf.import_graph_def(frozen_graph_def, name='')

    builder = tf.saved_model.builder.SavedModelBuilder(saved_model_path)

    if isinstance(inputs, dict):
      tensor_info_inputs = {name: to_info(t) for name, t in inputs.items()}
    else:
      tensor_info_inputs = {'inputs': to_info(inputs)}
    tensor_info_outputs = {name: to_info(t) for name, t in outputs.items()}

    detection_signature = tf.saved_model.signature_def_utils.build_signature_def(
        inputs=tensor_info_inputs,
        outputs=tensor_info_outputs,
        method_name=signature_constants.PREDICT_METHOD_NAME)
    signature_map = {
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
            detection_signature,
    }

    builder.add_meta_graph_and_variables(
        sess,
        [tf.saved_model.tag_constants.SERVING],
        signature_def_map=signature_map,
    )
    builder.save()
|
|
|
|
|
def write_graph_and_checkpoint(inference_graph_def,
                               model_path,
                               input_saver_def,
                               trained_checkpoint_prefix):
  """Writes the graph and the checkpoint into disk.

  Device assignments are cleared from `inference_graph_def` in place. The graph
  is then imported into a fresh tf.Graph, its variables are restored from
  `trained_checkpoint_prefix` with a saver built from `input_saver_def`, and
  they are saved to `model_path` using relative paths.
  """
  for graph_node in inference_graph_def.node:
    graph_node.device = ''
  with tf.Graph().as_default():
    tf.import_graph_def(inference_graph_def, name='')
    with tf.Session() as session:
      checkpoint_saver = tf.train.Saver(saver_def=input_saver_def,
                                        save_relative_paths=True)
      checkpoint_saver.restore(session, trained_checkpoint_prefix)
      checkpoint_saver.save(session, model_path)
|
|
|
|
|
def _get_outputs_from_inputs(input_tensors, detection_model,
                             output_collection_name, **side_inputs):
  """Runs the model's preprocess/predict/postprocess and names the outputs.

  `input_tensors` is cast to float32 first; `side_inputs` are forwarded to
  `detection_model.predict`. The postprocessed tensors are wrapped by
  `add_output_tensor_nodes`, whose dict is returned.
  """
  float_inputs = tf.cast(input_tensors, dtype=tf.float32)
  preprocessed, image_shapes = detection_model.preprocess(float_inputs)
  postprocessed = detection_model.postprocess(
      detection_model.predict(preprocessed, image_shapes, **side_inputs),
      image_shapes)
  return add_output_tensor_nodes(postprocessed, output_collection_name)
|
|
|
|
|
def build_detection_graph(input_type, detection_model, input_shape,
                          output_collection_name, graph_hook_fn,
                          use_side_inputs=False, side_input_shapes=None,
                          side_input_names=None, side_input_types=None):
  """Build the detection graph in the current default graph.

  Args:
    input_type: key into `input_placeholder_fn_map` selecting the input
      placeholder builder.
    detection_model: model object exposing preprocess/predict/postprocess.
    input_shape: optional shape forwarded to the placeholder builder.
    output_collection_name: collection name passed to the output-node builder.
    graph_hook_fn: optional zero-argument callable invoked after the graph and
      the global step have been created.
    use_side_inputs: if True, one placeholder per side input is created and
      fed to `detection_model.predict` as a keyword argument.
    side_input_shapes: list of side input shapes (required with side inputs).
    side_input_names: list of side input names (required with side inputs).
    side_input_types: list of side input dtypes (required with side inputs).

  Returns:
    A tuple (outputs, placeholder_tensors): the output tensor dict, and a dict
    mapping 'inputs' plus each side input name to its placeholder.

  Raises:
    ValueError: if `input_type` is unknown, if `input_shape` is given for an
      unsupported input type, or if side inputs are requested but the side
      input lists are missing or have different lengths.
  """
  if input_type not in input_placeholder_fn_map:
    raise ValueError('Unknown input type: {}'.format(input_type))
  placeholder_args = {}
  if input_shape is not None:
    if input_type not in ('image_tensor', 'encoded_image_string_tensor',
                          'tf_example', 'tf_sequence_example'):
      raise ValueError('Can only specify input shape for `image_tensor`, '
                       '`encoded_image_string_tensor`, `tf_example`, '
                       ' or `tf_sequence_example` inputs.')
    placeholder_args['input_shape'] = input_shape
  placeholder_tensor, input_tensors = input_placeholder_fn_map[input_type](
      **placeholder_args)
  placeholder_tensors = {'inputs': placeholder_tensor}

  side_inputs = {}
  if use_side_inputs:
    if (side_input_names is None or side_input_shapes is None or
        side_input_types is None):
      raise ValueError('side_input_shapes, side_input_names and '
                       'side_input_types are required when use_side_inputs '
                       'is True.')
    if not (len(side_input_names) == len(side_input_shapes) ==
            len(side_input_types)):
      raise ValueError(
          'side_input_shapes, side_input_names and side_input_types must have '
          'the same length, got {}, {} and {}.'.format(
              len(side_input_shapes), len(side_input_names),
              len(side_input_types)))
    for side_input_name, side_input_shape, side_input_type in zip(
        side_input_names, side_input_shapes, side_input_types):
      side_input_placeholder, side_input = _side_input_tensor_placeholder(
          side_input_shape, side_input_name, side_input_type)
      side_inputs[side_input_name] = side_input
      placeholder_tensors[side_input_name] = side_input_placeholder

  outputs = _get_outputs_from_inputs(
      input_tensors=input_tensors,
      detection_model=detection_model,
      output_collection_name=output_collection_name,
      **side_inputs)

  slim.get_or_create_global_step()

  if graph_hook_fn:
    graph_hook_fn()

  return outputs, placeholder_tensors
|
|
|
|
|
def _export_inference_graph(input_type,
                            detection_model,
                            use_moving_averages,
                            trained_checkpoint_prefix,
                            output_directory,
                            additional_output_tensor_names=None,
                            input_shape=None,
                            output_collection_name='inference_op',
                            graph_hook_fn=None,
                            write_inference_graph=False,
                            temp_checkpoint_prefix='',
                            use_side_inputs=False,
                            side_input_shapes=None,
                            side_input_names=None,
                            side_input_types=None):
  """Export helper.

  Builds the detection graph in the current default graph and writes into
  `output_directory`:
    * model.ckpt: variables restored from the trained checkpoint (replaced by
      their moving averages when `use_moving_averages` is True),
    * inference_graph.pbtxt: text GraphDef, only if `write_inference_graph`,
    * frozen_inference_graph.pb: the frozen GraphDef,
    * saved_model/: a SavedModel built from the frozen graph.

  When moving averages are used and no `temp_checkpoint_prefix` is given, the
  intermediate moving-average checkpoint is written into a private temporary
  directory that is deleted before returning. A caller-supplied
  `temp_checkpoint_prefix` is used as-is and never deleted.
  """
  tf.gfile.MakeDirs(output_directory)
  frozen_graph_path = os.path.join(output_directory,
                                   'frozen_inference_graph.pb')
  saved_model_path = os.path.join(output_directory, 'saved_model')
  model_path = os.path.join(output_directory, 'model.ckpt')

  outputs, placeholder_tensor_dict = build_detection_graph(
      input_type=input_type,
      detection_model=detection_model,
      input_shape=input_shape,
      output_collection_name=output_collection_name,
      graph_hook_fn=graph_hook_fn,
      use_side_inputs=use_side_inputs,
      side_input_shapes=side_input_shapes,
      side_input_names=side_input_names,
      side_input_types=side_input_types)

  profile_inference_graph(tf.get_default_graph())

  # Temporary directory created by this function (if any); removed in
  # `finally` so intermediate checkpoints do not accumulate in the temp dir.
  owned_temp_dir = None
  try:
    saver_kwargs = {}
    if use_moving_averages:
      if not temp_checkpoint_prefix:
        # A prefix that is a regular file presumably denotes a V1 checkpoint;
        # keep writing V1 in that case.
        if os.path.isfile(trained_checkpoint_prefix):
          saver_kwargs['write_version'] = saver_pb2.SaverDef.V1
        # Previously a deleted NamedTemporaryFile name or the mkdtemp path
        # itself was used as the prefix, scattering checkpoint files directly
        # into the system temp dir; keep them inside a private directory.
        owned_temp_dir = tempfile.mkdtemp()
        temp_checkpoint_prefix = os.path.join(owned_temp_dir, 'model.ckpt')
      replace_variable_values_with_moving_averages(
          tf.get_default_graph(), trained_checkpoint_prefix,
          temp_checkpoint_prefix)
      checkpoint_to_use = temp_checkpoint_prefix
    else:
      checkpoint_to_use = trained_checkpoint_prefix

    saver = tf.train.Saver(**saver_kwargs)
    input_saver_def = saver.as_saver_def()

    write_graph_and_checkpoint(
        inference_graph_def=tf.get_default_graph().as_graph_def(),
        model_path=model_path,
        input_saver_def=input_saver_def,
        trained_checkpoint_prefix=checkpoint_to_use)
    if write_inference_graph:
      inference_graph_def = tf.get_default_graph().as_graph_def()
      inference_graph_path = os.path.join(output_directory,
                                          'inference_graph.pbtxt')
      for node in inference_graph_def.node:
        node.device = ''
      with tf.gfile.GFile(inference_graph_path, 'wb') as f:
        f.write(str(inference_graph_def))

    if additional_output_tensor_names is not None:
      output_node_names = ','.join(list(outputs.keys())+(
          additional_output_tensor_names))
    else:
      output_node_names = ','.join(outputs.keys())

    frozen_graph_def = freeze_graph.freeze_graph_with_def_protos(
        input_graph_def=tf.get_default_graph().as_graph_def(),
        input_saver_def=input_saver_def,
        input_checkpoint=checkpoint_to_use,
        output_node_names=output_node_names,
        restore_op_name='save/restore_all',
        filename_tensor_name='save/Const:0',
        output_graph=frozen_graph_path,
        clear_devices=True,
        initializer_nodes='')
  finally:
    if owned_temp_dir is not None:
      shutil.rmtree(owned_temp_dir, ignore_errors=True)

  # The SavedModel is built from the frozen GraphDef and needs no checkpoint.
  write_saved_model(saved_model_path, frozen_graph_def,
                    placeholder_tensor_dict, outputs)
|
|
|
|
|
def export_inference_graph(input_type,
                           pipeline_config,
                           trained_checkpoint_prefix,
                           output_directory,
                           input_shape=None,
                           output_collection_name='inference_op',
                           additional_output_tensor_names=None,
                           write_inference_graph=False,
                           use_side_inputs=False,
                           side_input_shapes=None,
                           side_input_names=None,
                           side_input_types=None):
  """Exports inference graph for the model specified in the pipeline config.

  Also saves a copy of `pipeline_config` to `output_directory` with
  `eval_config.use_moving_averages` set to False, since any moving averages
  have already been baked into the exported checkpoint. The caller's
  `pipeline_config` is not modified.

  Args:
    input_type: Type of input for the graph. Can be one of ['image_tensor',
      'encoded_image_string_tensor', 'tf_example'].
    pipeline_config: pipeline_pb2.TrainAndEvalPipelineConfig proto.
    trained_checkpoint_prefix: Path to the trained checkpoint file.
    output_directory: Path to write outputs.
    input_shape: Sets a fixed shape for an `image_tensor` input. If not
      specified, will default to [None, None, None, 3]. For
      `encoded_image_string_tensor` and `tf_example` inputs, decoded images
      are resized to entries [1:3] of this shape.
    output_collection_name: Name of collection to add output tensors to.
      If None, does not add output tensors to a collection.
    additional_output_tensor_names: list of additional output
      tensors to include in the frozen graph.
    write_inference_graph: If true, writes inference graph to disk.
    use_side_inputs: If True, the model requires side_inputs.
    side_input_shapes: List of shapes of the side input tensors,
      required if use_side_inputs is True.
    side_input_names: List of names of the side input tensors,
      required if use_side_inputs is True.
    side_input_types: List of types of the side input tensors,
      required if use_side_inputs is True.

  Raises:
    ValueError: propagated from graph construction for invalid input types,
      input shapes, or side input settings.
  """
  detection_model = model_builder.build(pipeline_config.model,
                                        is_training=False)
  graph_rewriter_fn = None
  if pipeline_config.HasField('graph_rewriter'):
    graph_rewriter_config = pipeline_config.graph_rewriter
    graph_rewriter_fn = graph_rewriter_builder.build(graph_rewriter_config,
                                                     is_training=False)
  _export_inference_graph(
      input_type,
      detection_model,
      pipeline_config.eval_config.use_moving_averages,
      trained_checkpoint_prefix,
      output_directory,
      additional_output_tensor_names,
      input_shape,
      output_collection_name,
      graph_hook_fn=graph_rewriter_fn,
      write_inference_graph=write_inference_graph,
      use_side_inputs=use_side_inputs,
      side_input_shapes=side_input_shapes,
      side_input_names=side_input_names,
      side_input_types=side_input_types)
  # Edit a copy: mutating the caller's proto would silently flip its
  # use_moving_averages flag for any later use (e.g. a second export).
  config_to_save = type(pipeline_config)()
  config_to_save.CopyFrom(pipeline_config)
  config_to_save.eval_config.use_moving_averages = False
  config_util.save_pipeline_config(config_to_save, output_directory)
|
|
|
|
|
def profile_inference_graph(graph):
  """Profiles the inference graph.

  Prints model parameters and computation FLOPs given an inference graph.
  BatchNorms are excluded from the parameter count due to the fact that
  BatchNorms are usually folded. BatchNorm, Initializer, Regularizer
  and BiasAdd are not considered in FLOP count.

  Relies on `contrib_tfprof`, which is only bound when the optional
  `tensorflow.contrib` import at the top of this module succeeds.

  Args:
    graph: the inference graph.
  """
  model_analyzer = contrib_tfprof.model_analyzer
  # Work on copies: the library's option objects are shared module-level
  # dicts, so assigning into them would leak these trim settings into every
  # other tfprof user in the process.
  tfprof_vars_option = dict(model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS)
  tfprof_flops_option = dict(model_analyzer.FLOAT_OPS_OPTIONS)

  # BatchNorm is usually folded during inference.
  tfprof_vars_option['trim_name_regexes'] = ['.*BatchNorm.*']
  # Initializer and Regularizer are presumably training-only; BiasAdd and
  # BatchNorm are excluded from the FLOP count as documented above.
  tfprof_flops_option['trim_name_regexes'] = [
      '.*BatchNorm.*', '.*Initializer.*', '.*Regularizer.*', '.*BiasAdd.*'
  ]

  model_analyzer.print_model_analysis(
      graph, tfprof_options=tfprof_vars_option)

  model_analyzer.print_model_analysis(
      graph, tfprof_options=tfprof_flops_option)
|
|
|