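r"""Freeze a TensorFlow checkpoint and optimize the frozen graph for mobile use.

Requires a TensorFlow source checkout in which the graph_transforms tool
(transform_graph) has been built with bazel; pass its root as --tf_path.

Usage:
    python graph_optimizer.py \
        --tf_path ../../tensorflow/ \
        --model "path_to_the_model_folder" \
        --output_names "activation,accuracy" \
        --input_names "x"
"""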
| import os, argparse | |
| from subprocess import call | |
| import freeze_graph | |
| import tensorflow as tf | |
# Directory holding this script, with symlinks resolved.
# NOTE(review): this name shadows the builtin ``dir`` and is not referenced
# anywhere else in this file.
dir = os.path.split(os.path.realpath(__file__))[0]

# Suffixes appended to the checkpoint prefix for each output stage.
fr_name = "_frozen.pb"  # frozen GraphDef produced by graph_freez
op_name = "_optimized.pb"  # transformed GraphDef produced by graph_optimization
def graph_freez(model_folder, output_names):
    """Freeze the latest checkpoint found in ``model_folder``.

    The graph definition is read from the checkpoint's ``.meta`` file and the
    variable values from the checkpoint itself. Both are handed to the local
    ``freeze_graph`` module, which writes the frozen graph to disk.

    Args:
        model_folder: Directory given to ``tf.train.get_checkpoint_state``,
            i.e. the folder containing the TensorFlow ``checkpoint`` file.
        output_names: Comma-separated output node names. Whitespace around
            each name is ignored, so ``"activation, accuracy"`` is accepted.

    Returns:
        Path of the frozen graph: the latest checkpoint prefix + ``fr_name``.

    Raises:
        ValueError: If ``output_names`` contains no node name, or if no
            checkpoint is found in ``model_folder``.
    """
    print("Model folder", model_folder)

    # freeze_graph splits the names on "," without stripping whitespace, so
    # " accuracy" would not match any node. TF node names cannot contain
    # spaces, so stripping them is safe.
    output_node_names = ",".join(
        name.strip() for name in output_names.split(",") if name.strip())
    if not output_node_names:
        raise ValueError("graph_freez needs at least one output node name")

    checkpoint = tf.train.get_checkpoint_state(model_folder)
    print(checkpoint)
    # get_checkpoint_state returns None when the folder has no checkpoint
    # state file; report that clearly instead of failing with AttributeError.
    if checkpoint is None or not checkpoint.model_checkpoint_path:
        raise ValueError("No checkpoint found in %r" % (model_folder,))
    checkpoint_path = checkpoint.model_checkpoint_path
    output_graph_filename = checkpoint_path + fr_name

    freeze_graph.freeze_graph(
        "",                     # input_graph: empty, the .meta graph is used
        "",                     # input_saver
        True,                   # input_binary
        checkpoint_path,        # input_checkpoint
        output_node_names,      # output_node_names
        "save/restore_all",     # restore_op_name
        "save/Const:0",         # filename_tensor_name
        output_graph_filename,  # output_graph
        False,                  # clear_devices
        "",                     # initializer_nodes
        # By keyword: newer freeze_graph versions insert
        # variable_names_whitelist before variable_names_blacklist, so a
        # positional meta-graph path would bind to the blacklist parameter.
        input_meta_graph=checkpoint_path + ".meta")
    return output_graph_filename
def graph_optimization(tf_path, graph_file, input_names, output_names,
                       input_shape="1,299,299,3"):
    """Run TensorFlow's bazel-built ``transform_graph`` tool on a frozen graph.

    Applies strip_unused_nodes, fold_constants, fold_batch_norms and
    fold_old_batch_norms, and writes the result next to ``graph_file``.

    Args:
        tf_path: Root of the TensorFlow source checkout in which
            ``bazel-bin/tensorflow/tools/graph_transforms/transform_graph``
            exists. A trailing path separator is optional.
        graph_file: Path of the frozen graph, normally produced by
            ``graph_freez`` (ending in ``fr_name``).
        input_names: Comma-separated input node names; whitespace around
            each name is ignored.
        output_names: Comma-separated output node names; whitespace around
            each name is ignored.
        input_shape: Comma-separated shape passed to ``strip_unused_nodes``.
            The default keeps the previously hard-coded ``"1,299,299,3"``.

    Returns:
        Path of the optimized graph file.

    Raises:
        subprocess.CalledProcessError: If ``transform_graph`` exits non-zero.
    """
    import subprocess

    def normalize(names):
        # transform_graph splits on "," without trimming, so drop the spaces.
        return ",".join(n.strip() for n in names.split(",") if n.strip())

    # Replace the frozen-graph suffix when present; otherwise drop the file
    # extension instead of chopping off an arbitrary number of characters.
    if graph_file.endswith(fr_name):
        base = graph_file[:-len(fr_name)]
    else:
        base = os.path.splitext(graph_file)[0]
    output_file = base + op_name

    tool = os.path.join(tf_path, "bazel-bin", "tensorflow", "tools",
                        "graph_transforms", "transform_graph")
    transforms = " ".join([
        'strip_unused_nodes(type=float, shape="%s")' % input_shape,
        "fold_constants(ignore_errors=true)",
        "fold_batch_norms",
        "fold_old_batch_norms",
    ])
    # check_call raises on a non-zero exit instead of failing silently.
    subprocess.check_call([
        tool,
        "--in_graph=" + graph_file,
        "--out_graph=" + output_file,
        "--inputs=" + normalize(input_names),
        "--outputs=" + normalize(output_names),
        "--transforms=" + transforms,
    ])
    return output_file
def main(argv=None):
    """Parse command-line arguments, freeze the model, then optimize it.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    # The first positional ArgumentParser argument is ``prog``, which would
    # replace the program name in the usage line; this text is a description.
    parser = argparse.ArgumentParser(
        description="Script freezes graph and optimizes it for mobile usage")
    parser.add_argument(
        "--model", "--model_folder",
        dest="model",
        type=str,
        required=True,
        help="Folder containing the model checkpoint (the directory that "
             "holds the TensorFlow 'checkpoint' file)")
    parser.add_argument(
        "--input_names",
        type=str,
        default="",
        help="Input node names, comma separated.")
    parser.add_argument(
        "--output_names",
        type=str,
        required=True,
        help="Output node names, comma separated.")
    parser.add_argument(
        "--tf_path",
        type=str,
        default="../../tensorflow/",
        help="Path to the folder with tensorflow (requires bazel build of "
             "graph_transforms)")
    args = parser.parse_args(argv)
    graph = graph_freez(args.model, args.output_names)
    graph_optimization(args.tf_path, graph, args.input_names,
                       args.output_names)


if __name__ == '__main__':
    main()