dgarrett-synaptics commited on
Commit
3001092
·
verified ·
1 Parent(s): 993bb26

Delete generate_micro_mutable_op_resolver_from_model.py

Browse files
generate_micro_mutable_op_resolver_from_model.py DELETED
@@ -1,178 +0,0 @@
1
- import os
2
- import re
3
- from mako.template import Template
4
- from tensorflow.lite.tools import visualize as visualize
5
- from mako import template
6
- from pathlib import Path
7
- import platform
8
-
9
- def generate_micro_mutable_ops_resolver_header(common_tflite_path, input_tflite_files, output_dir, namespace, license_header, verify_op_list_against_header=None):
10
- TEMPLATE_DIR = os.path.abspath('templates')
11
-
12
- def parse_string(word):
13
- """Converts a flatbuffer operator string to a format suitable for Micro
14
- Mutable Op Resolver. Example: CONV_2D --> AddConv2D."""
15
-
16
- # Edge case for AddDetectionPostprocess().
17
- # The custom code is TFLite_Detection_PostProcess.
18
- word = word.replace('TFLite', '')
19
-
20
- word_split = re.split('_|-', word)
21
- formated_op_string = ''
22
- for part in word_split:
23
- if len(part) > 1:
24
- if part[0].isalpha():
25
- formated_op_string += part[0].upper() + part[1:].lower()
26
- else:
27
- formated_op_string += part.upper()
28
- else:
29
- formated_op_string += part.upper()
30
- return 'Add' + formated_op_string
31
-
32
- def GetModelOperatorsAndActivation(model_path):
33
- """Extracts a set of operators from a tflite model."""
34
-
35
- custom_op_found = False
36
- operators_and_activations = set()
37
-
38
- with open(model_path, 'rb') as f:
39
- data_bytes = bytearray(f.read())
40
-
41
- data = visualize.CreateDictFromFlatbuffer(data_bytes)
42
-
43
- for op_code in data["operator_codes"]:
44
- if op_code['custom_code'] is None:
45
- op_code["builtin_code"] = max(op_code["builtin_code"],
46
- op_code["deprecated_builtin_code"])
47
- else:
48
- custom_op_found = True
49
- operators_and_activations.add(
50
- visualize.NameListToString(op_code['custom_code']))
51
-
52
- for op_code in data["operator_codes"]:
53
- # Custom operator already added.
54
- if custom_op_found and visualize.BuiltinCodeToName(
55
- op_code['builtin_code']) == "CUSTOM":
56
- continue
57
-
58
- operators_and_activations.add(
59
- visualize.BuiltinCodeToName(op_code['builtin_code']))
60
-
61
- return operators_and_activations
62
-
63
- def GenerateMicroMutableOpsResolverHeaderFile(operators, name_of_model,
64
- output_dir, namespace):
65
- """Generates Micro Mutable Op Resolver code based on a template."""
66
-
67
- number_of_ops = len(operators)
68
- outfile = 'micro_mutable_op_resolver.hpp'
69
-
70
- # Get the path to the directory containing this script
71
- script_dir = Path(__file__).parent
72
-
73
- # Construct the relative path to the template file
74
- template_file_path = script_dir / 'templates' / (outfile + '.mako')
75
-
76
- # Generate the resolver file with the template
77
- build_template = Template(filename=str(template_file_path))
78
-
79
- output_dir = Path(output_dir).resolve()
80
- if platform.system() == "Windows":
81
- output_path = str(output_dir) + "\\" + (namespace + "_" + outfile)
82
- else:
83
- output_path = str(output_dir) + "/" + (namespace + "_" + outfile)
84
-
85
- with open(output_path, 'w') as file_obj:
86
- key_values_in_template = {
87
- 'model': name_of_model,
88
- 'number_of_ops': number_of_ops,
89
- 'operators': operators,
90
- 'namespace': namespace,
91
- 'common_template_header': license_header
92
- }
93
- file_obj.write(build_template.render(**key_values_in_template))
94
-
95
- def verify_op_list(op_list, header):
96
- """
97
- Verifies that all operations in op_list are supported by TFLM, as declared in the header file.
98
-
99
- Args:
100
- op_list (list): A list of operation names to verify.
101
- header (str): Path to the header file containing declarations of supported operations.
102
-
103
- Returns:
104
- bool: True if any operation in op_list is not supported, False otherwise.
105
- """
106
- # Read the header file and extract supported operations
107
- supported_op_list = []
108
- with open(header, 'r') as f:
109
- for line in f:
110
- # Assuming the header file declares operations in the form "TfLiteStatus Add<OpName>(...);"
111
- match = re.search(r"TfLiteStatus Add(\w+)\(.*\);", line)
112
- if match:
113
- supported_op = match.group(1)
114
- supported_op_list.append(supported_op)
115
-
116
- # Check if all operations in op_list are in supported_op_list
117
- unsupported_ops = [op for op in op_list if op not in supported_op_list]
118
- if unsupported_ops:
119
- print(f"The following operations are not supported by TFLM: {', '.join(unsupported_ops)}")
120
- return True # Indicating verification failed due to unsupported operations
121
-
122
- return False # All operations are supported
123
-
124
- model_names = []
125
- final_operator_list = []
126
- merged_operator_list = []
127
-
128
- for relative_model_path in input_tflite_files:
129
- full_model_path = f"{common_tflite_path}/{relative_model_path}"
130
- operators = GetModelOperatorsAndActivation(full_model_path)
131
- model_name = os.path.basename(full_model_path)
132
- model_names.append(model_name)
133
-
134
- parsed_operator_list = [parse_string(op) for op in sorted(operators)]
135
- merged_operator_list.extend(parsed_operator_list)
136
-
137
- final_operator_list = sorted(set(merged_operator_list))
138
-
139
- if verify_op_list_against_header:
140
- if verify_op_list(final_operator_list, verify_op_list_against_header):
141
- print("Verification failed.")
142
- return
143
-
144
- os.makedirs(output_dir, exist_ok=True)
145
- GenerateMicroMutableOpsResolverHeaderFile(final_operator_list, model_name,
146
- output_dir, namespace)
147
-
148
- # Optionally, keep the command-line interface for standalone usage
149
- if __name__ == '__main__':
150
- from absl import app
151
- from absl import flags
152
-
153
- FLAGS = flags.FLAGS
154
- flags.DEFINE_string('common_tflite_path', None, 'Common path to tflite files.')
155
- flags.DEFINE_list('input_tflite_files', None, 'List of input TFLite files.')
156
- flags.DEFINE_string('output_dir', None, 'Directory to output generated files.')
157
- flags.DEFINE_string('namespace', None, 'Namespace for the generated code.')
158
- flags.DEFINE_string('license_header', None, 'License header')
159
- flags.DEFINE_string('verify_op_list_against_header', None, 'Header file to verify the operation list against.')
160
-
161
- flags.mark_flag_as_required('common_tflite_path')
162
- flags.mark_flag_as_required('input_tflite_files')
163
- flags.mark_flag_as_required('output_dir')
164
- flags.mark_flag_as_required('namespace')
165
- flags.mark_flag_as_required('license_header')
166
-
167
- def main(argv):
168
- print("generate_micro_mutable_ops_resolver_header")
169
- generate_micro_mutable_ops_resolver_header(
170
- FLAGS.common_tflite_path,
171
- FLAGS.input_tflite_files,
172
- FLAGS.output_dir,
173
- FLAGS.namespace,
174
- FLAGS.license_header,
175
- FLAGS.verify_op_list_against_header
176
- )
177
-
178
- app.run(main)