dgarrett-synaptics commited on
Commit
8a59108
·
verified ·
1 Parent(s): 3001092

Delete infer_code_gen.py

Browse files
Files changed (1) hide show
  1. infer_code_gen.py +0 -191
infer_code_gen.py DELETED
@@ -1,191 +0,0 @@
1
- import argparse
2
- import os
3
- import subprocess
4
- from pathlib import Path
5
- from gen_model_cpp import generate_model_cpp
6
- from gen_input_expected_data import generate_input_expected_data
7
- from generate_micro_mutable_op_resolver_from_model import generate_micro_mutable_ops_resolver_header
8
- from jinja2 import Environment, FileSystemLoader
9
- import datetime
10
- import glob
11
- import platform
12
-
13
- # Function to expand wildcards in input paths
14
- def expand_wildcards(file_paths):
15
- expanded_paths = []
16
- for path in file_paths:
17
- # Check if the path contains a wildcard
18
- if '*' in path:
19
- # Expand the wildcard to actual file names and sort them
20
- expanded_file_paths = sorted(glob.glob(path))
21
- # Extend the list with the sorted paths
22
- expanded_paths.extend(expanded_file_paths)
23
- else:
24
- # If no wildcard, add the path as is
25
- expanded_paths.append(path)
26
- return expanded_paths
27
-
28
- # Parse command line arguments
29
- parser = argparse.ArgumentParser(description='Wrapper script to run TFLite model generation scripts.')
30
- parser.add_argument('-t', '--tflite_path', type=str, help='Path to TFLite model file', required=True)
31
- parser.add_argument('-o', '--output_dir', type=str, help='Directory to output generated files', default='.')
32
- parser.add_argument('-n', '--namespace', type=str, help='Namespace to use for generated code', default='model')
33
- parser.add_argument('-s', '--script', type=str, nargs='+', choices=['model', 'inout'],
34
- help='Choose specific scripts to run, if not provided then run all scripts, separated by spaces')
35
- parser.add_argument('-i', '--input', type=str, nargs='+', help='List of input npy/bin files')
36
- parser.add_argument('-c', '--compiler', type=str, choices=['vela', 'synai', 'none'], help='Choose target compiler', default='vela')
37
- parser.add_argument('-tl', '--tflite_loc', type=int, choices=[1, 2], help="Choose an option (1: SRAM, 2: FLASH)", default=1, required=False)
38
- parser.add_argument('-p', '--optimize', type=str, choices=['Performance', 'Size'], help="Choose optimization Type", default='Performance', required=False)
39
- args = parser.parse_args()
40
-
41
- synai_ethosu_op_found = 0
42
-
43
- # Expand wildcards in input file paths
44
- if args.input:
45
- args.input = expand_wildcards(args.input)
46
-
47
- memory_mode = ''
48
-
49
- if args.tflite_loc == 1:
50
- memory_mode = '--memory-mode=Sram_Only'
51
- else:
52
- memory_mode = '--memory-mode=Shared_Sram'
53
-
54
- # Determine which scripts to run
55
- scripts_to_run = []
56
- if args.script:
57
- scripts_to_run = args.script
58
- else:
59
- scripts_to_run = ['model', 'inout']
60
-
61
- # Check if vela compilation is needed or not
62
- # If not that means the user is trying to run a non-vela model
63
- # In that case force the file name to no vela one
64
- file_name = os.path.basename(args.tflite_path)
65
- args.tflite_path = os.path.abspath(args.tflite_path)
66
-
67
- if args.compiler == 'vela':
68
- new_tflite_file_name = file_name.split('.')[0] + '_vela.tflite'
69
- elif args.compiler == 'synai':
70
- new_tflite_file_name = file_name.split('.')[0] + '_synai.tflite'
71
- elif args.compiler == 'none':
72
- new_tflite_file_name = os.path.basename(args.tflite_path)
73
- else:
74
- print("Invalid compiler option")
75
- exit(1)
76
-
77
- if platform.system() == "Windows":
78
- new_tflite_path = os.path.dirname(args.tflite_path) + "\\" + new_tflite_file_name
79
- else:
80
- new_tflite_path = os.path.dirname(args.tflite_path) + "/" + new_tflite_file_name
81
- # Get the path to the directory containing this script
82
- script_dir = Path(__file__).parent
83
-
84
- # Initialize Jinja2 environment
85
- env = Environment(loader=FileSystemLoader(script_dir / 'templates'),
86
- trim_blocks=True,
87
- lstrip_blocks=True)
88
- header_template = env.get_template("header_template.txt")
89
- license_header = header_template.render(script_name=script_dir.name,
90
- file_name=Path(args.tflite_path).name,
91
- gen_time=datetime.datetime.now(),
92
- year=datetime.datetime.now().year)
93
-
94
- if args.compiler == 'vela':
95
- # Generate vela optimized model
96
- print("************ VELA ************")
97
- if platform.system() == "Windows":
98
- vela_params = ['vela', '--output-dir', os.path.dirname(args.tflite_path), '--accelerator-config=ethos-u55-128' , '--optimise=' + args.optimize, '--config=Arm\\vela.ini', memory_mode, '--system-config=Ethos_U55_High_End_Embedded', args.tflite_path]
99
- else:
100
- vela_params = ['vela', '--output-dir', os.path.dirname(args.tflite_path), '--accelerator-config=ethos-u55-128' , '--optimise=' + args.optimize, '--config=Arm/vela.ini', memory_mode, '--system-config=Ethos_U55_High_End_Embedded', args.tflite_path]
101
- subprocess.run(vela_params)
102
- print("********* END OF VELA *********")
103
- elif args.compiler == 'synai':
104
- # Generate synai optimized model
105
- print("*********** SYNAI **********")
106
- synai_params = ['synai', '--output-dir', os.path.dirname(args.tflite_path), args.tflite_path]
107
- subprocess.run(synai_params)
108
- print("******** END OF SYNAI ********")
109
- else:
110
- print("******* No Compilation *******")
111
-
112
- # Run the selected scripts
113
- for script in scripts_to_run:
114
- if script == 'model':
115
- # Generate model C++ code
116
- generate_model_cpp(new_tflite_path, args.output_dir, args.namespace, args.tflite_loc, license_header)
117
-
118
- # Generate micro mutable op resolver code
119
- common_path = os.path.dirname(new_tflite_path)
120
- if common_path == '':
121
- common_path = '.'
122
- generate_micro_mutable_ops_resolver_header(common_path, [os.path.basename(new_tflite_path)], args.output_dir,
123
- args.namespace, license_header)
124
-
125
- # Open the source file in read mode and the destination file in append mode
126
- if platform.system() == "Windows":
127
- with open(args.output_dir + "\\" + args.namespace + '_micro_mutable_op_resolver.hpp', 'r') as source_file, \
128
- open(args.output_dir + "\\" + args.namespace + '.cc', 'a') as destination_file:
129
- # Read the content from the source file
130
- content = source_file.read()
131
- # Append the content to the destination file
132
- destination_file.write(content)
133
- else:
134
- with open(args.output_dir + "/" + args.namespace + '_micro_mutable_op_resolver.hpp', 'r') as source_file, \
135
- open(args.output_dir + "/" + args.namespace + '.cc', 'a') as destination_file:
136
- # Read the content from the source file
137
- content = source_file.read()
138
- # Append the content to the destination file
139
- destination_file.write(content)
140
-
141
- generate_micro_mutable_ops_resolver_header(common_path, [os.path.basename(args.tflite_path)], args.output_dir,
142
- "orig", license_header)
143
-
144
- if platform.system() == "Windows":
145
- with open(args.output_dir + "\\" + 'orig_micro_mutable_op_resolver.hpp', 'r') as source_file:
146
- content = source_file.read()
147
- if 'AddSynai' in content:
148
- synai_ethosu_op_found = 1
149
- elif 'AddEthosU' in content:
150
- synai_ethosu_op_found = 2
151
- else:
152
- synai_ethosu_op_found = 0
153
- else:
154
- with open(args.output_dir + "/" + 'orig_micro_mutable_op_resolver.hpp', 'r') as source_file:
155
- content = source_file.read()
156
- if 'AddSynai' in content:
157
- synai_ethosu_op_found = 1
158
- elif 'AddEthosU' in content:
159
- synai_ethosu_op_found = 2
160
- else:
161
- synai_ethosu_op_found = 0
162
-
163
- # Delete micro mutable op resolver file if it exists
164
- if platform.system() == "Windows":
165
- micro_mutable_file = args.output_dir + "\\" + args.namespace + '_micro_mutable_op_resolver.hpp'
166
- else:
167
- micro_mutable_file = args.output_dir + "/" + args.namespace + '_micro_mutable_op_resolver.hpp'
168
- if os.path.exists(micro_mutable_file):
169
- os.remove(micro_mutable_file)
170
-
171
- # Delete micro mutable op resolver file if it exists
172
- if platform.system() == "Windows":
173
- micro_mutable_file = args.output_dir + "\\" + args.namespace + '_micro_mutable_op_resolver.hpp'
174
- else:
175
- micro_mutable_file = args.output_dir + "/" + 'orig_micro_mutable_op_resolver.hpp'
176
- if os.path.exists(micro_mutable_file):
177
- os.remove(micro_mutable_file)
178
-
179
-
180
- elif script == 'inout':
181
- # Check if AddSynai or AddEthosU is present in the contents of micro mutable op resolver
182
- if synai_ethosu_op_found > 0:
183
- if synai_ethosu_op_found == 1:
184
- print("Synai custom op found in the model, skipping expected output generation")
185
- else:
186
- print("EthosU custom op found in the model, skipping expected output generation")
187
- else:
188
- if args.input:
189
- generate_input_expected_data(args.tflite_path, args.output_dir, args.namespace, license_header, args.input)
190
- else:
191
- generate_input_expected_data(args.tflite_path, args.output_dir, args.namespace, license_header)