dgarrett-synaptics commited on
Commit
993bb26
·
verified ·
1 Parent(s): 1bde287

Delete gen_model_cpp.py

Browse files
Files changed (1) hide show
  1. gen_model_cpp.py +0 -96
gen_model_cpp.py DELETED
@@ -1,96 +0,0 @@
1
-
2
- import datetime
3
- from pathlib import Path
4
- from jinja2 import Environment, FileSystemLoader
5
- import binascii
6
- import platform
7
-
8
- # Define the choices and corresponding strings for tflite location
9
- loc_choices = {
10
- 1: "MODEL_TFLITE_ATTRIBUTE", # SRAM
11
- 2: "MODEL_TFLITE_ATTRIBUTE_FLASH", # QSPI FLASH
12
- }
13
-
14
- import datetime
15
- from pathlib import Path
16
- from jinja2 import Environment, FileSystemLoader
17
-
18
-
19
- def generate_model_cpp(tflite_path, output_dir, namespace, tflite_loc, license_header):
20
- tflite_loc_choice = loc_choices.get(tflite_loc, "MODEL_TFLITE_ATTRIBUTE")
21
-
22
- # Get the path to the directory containing this script
23
- script_dir = Path(__file__).parent
24
-
25
- # Initialize Jinja2 environment
26
- env = Environment(loader=FileSystemLoader(script_dir / 'templates'),
27
- trim_blocks=True,
28
- lstrip_blocks=True)
29
-
30
- if not Path(tflite_path).is_file():
31
- raise Exception(f"{tflite_path} not found")
32
-
33
- # Resolve relative paths for output_dir and cpp_filename
34
- output_dir = Path(output_dir).resolve()
35
- cpp_filename = (output_dir / (namespace + ".cc")).resolve()
36
- if platform.system() == "Windows":
37
- print(f"++ Converting {Path(tflite_path).name} to {output_dir}\{cpp_filename.name}")
38
- else:
39
- print(f"++ Converting {Path(tflite_path).name} to {output_dir}/{cpp_filename.name}")
40
-
41
- output_dir.mkdir(exist_ok=True)
42
-
43
- model_data, model_length = get_tflite_data(tflite_path)
44
- env.get_template('tflite.cc.template').stream(common_template_header=license_header,
45
- model_data=model_data,
46
- model_length=model_length,
47
- namespace=namespace,
48
- tflite_attribute=tflite_loc_choice).dump(str(cpp_filename))
49
-
50
-
51
- def get_tflite_data(tflite_path):
52
- """
53
- Reads a binary file and returns a C style array as a
54
- list of strings.
55
-
56
- Argument:
57
- tflite_path: path to the tflite model.
58
-
59
- Returns:
60
- tuple: (list of strings, int)
61
- - List of strings representing the C style array
62
- - Number of bytes in the binary file
63
- """
64
- with open(tflite_path, 'rb') as tflite_model:
65
- data = tflite_model.read()
66
-
67
- bytes_per_line = 32
68
- hex_digits_per_line = bytes_per_line * 2
69
- hexstream = binascii.hexlify(data).decode('utf-8')
70
- hexstring = '{'
71
-
72
- for i in range(0, len(hexstream), 2):
73
- if 0 == (i % hex_digits_per_line):
74
- hexstring += "\n"
75
- hexstring += '0x' + hexstream[i:i+2] + ", "
76
-
77
- hexstring += '};\n'
78
- return [hexstring], len(data)
79
-
80
- # Optionally, you can still keep the command-line interface for standalone usage
81
- if __name__ == '__main__':
82
- import argparse
83
- parser = argparse.ArgumentParser()
84
- parser.add_argument("--tflite_path", help="Model (.tflite) path", required=True)
85
- parser.add_argument("--output_dir", help="Output directory", required=True)
86
- parser.add_argument('-e', '--expression', action='append', default=[], dest="expr")
87
- parser.add_argument('--header', action='append', default=[], dest="headers")
88
- parser.add_argument('-ns', '--namespaces', action='append', default=[], dest="namespaces")
89
- parser.add_argument("--license_template", type=str, help="Header template file",
90
- default="header_template.txt")
91
- parser.add_argument('-tl','--tflite_loc', type=int, choices=loc_choices.keys(),
92
- help="Choose an option (1 : SRAM, 2 : FLASH)", default=1, required=False)
93
-
94
- args = parser.parse_args()
95
- license_header = ""
96
- generate_model_cpp(args.tflite_path, args.output_dir, args.namespace, args.tflite_loc, license_header)