File size: 2,656 Bytes
f915f2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
"""
Usage:
python graph_optimizer.py \
--tf_path ../../tensorflow/ \
--model_folder "path_to_the_model_folder" \
--output_names "activation, accuracy" \
--input_names "x"
"""

import os, argparse
from subprocess import call

import freeze_graph
import tensorflow as tf

dir = os.path.dirname(os.path.realpath(__file__))

fr_name = "_frozen.pb"
op_name = "_optimized.pb"


def graph_freez(model_folder, output_names):
    print("Model folder", model_folder)
    checkpoint = tf.train.get_checkpoint_state(model_folder)
    print(checkpoint)
    checkpoint_path = checkpoint.model_checkpoint_path
    output_graph_filename = checkpoint_path + fr_name

    input_saver_def_path = ""
    input_binary = True
    output_node_names = output_names
    restore_op_name = "save/restore_all"
    filename_tensor_name = "save/Const:0"
    clear_devices = False
    input_meta_graph = checkpoint_path + ".meta"

    freeze_graph.freeze_graph(
        "", input_saver_def_path, input_binary, checkpoint_path,
        output_node_names, restore_op_name, filename_tensor_name,
        output_graph_filename, clear_devices, "", "", input_meta_graph)

    return output_graph_filename


def graph_optimization(tf_path, graph_file, input_names, output_names):
    output_file = graph_file[:-len(fr_name)] + op_name
    tf_path += "bazel-bin/tensorflow/tools/graph_transforms/transform_graph"

    call([tf_path,
          "--in_graph=" + graph_file,
          "--out_graph=" + output_file,
          "--inputs=" + input_names,
          "--outputs=" + output_names,
          """--transforms=
          strip_unused_nodes(type=float, shape="1,299,299,3")
          fold_constants(ignore_errors=true)
          fold_batch_norms
          fold_old_batch_norms"""])


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
            "Script freezes graph and optimize it for mobile usage")
    parser.add_argument(
        "--model",
        type=str,
        help="Path of folder + model name (folder_path/model_name)")
    parser.add_argument(
        "--input_names",
        type=str,
        default="",
        help="Input node names, comma separated.")
    parser.add_argument(
        "--output_names",
        type=str,
        default="",
        help="Output node names, comma separated.")
    parser.add_argument(
        "--tf_path",
        type=str,
        default="../../tensorflow/",
        help="Path to the folder with tensorflow (requires bazel build of graph_transforms)")

    args = parser.parse_args()

    graph = graph_freez(args.model, args.output_names)
    graph_optimization(args.tf_path, graph, args.input_names, args.output_names)