| import os |
| import onnx |
| import shutil |
| import zipfile |
| import argparse |
|
|
|
|
| def create_triton_config(model_path, config_path, model_name, max_batch_size=0): |
| |
| model = onnx.load(model_path) |
| |
| input_tensors = [] |
| for i in model.graph.input: |
| shape = [dim.dim_value if dim.dim_value >= 1 else -1 for dim in i.type.tensor_type.shape.dim][1:] |
| input_tensors.append({"name": i.name, "data_type": "TYPE_FP32", "dims": shape}) |
| output_tensors = [] |
| for o in model.graph.output: |
| shape = [dim.dim_value if dim.dim_value >= 1 else -1 for dim in o.type.tensor_type.shape.dim] |
| |
| config = { |
| "name": model_name, |
| "backend": "onnxruntime", |
| "max_batch_size": max_batch_size, |
| "input": input_tensors, |
| "output": output_tensors, |
| "instance_group": [{"count": 1, "kind": "KIND_CPU"}], |
| } |
| |
| with open(config_path, 'w') as f: |
| f.write("name: \"" + config['name'] + "\"\n") |
| f.write("backend: \"" + config['backend'] + "\"\n") |
| f.write("max_batch_size: " + str(config['max_batch_size']) + "\n") |
| f.write("input [\n") |
| for input_tensor in config['input']: |
| f.write(" {\n") |
| f.write(" name: \"" + input_tensor['name'] + "\"\n") |
| f.write(" data_type: " + input_tensor['data_type'] + "\n") |
| f.write(" dims: [ " + ", ".join([str(dim) for dim in input_tensor['dims']]) + " ]\n") |
| f.write(" }\n") |
| f.write("]\n") |
| f.write("output [\n") |
| for output_tensor in config['output']: |
| f.write(" {\n") |
| f.write(" name: \"" + output_tensor['name'] + "\"\n") |
| f.write(" data_type: " + output_tensor['data_type'] + "\n") |
| f.write(" dims: [ " + ", ".join([str(dim) for dim in output_tensor['dims']]) + " ]\n") |
| f.write(" }\n") |
| f.write("]\n") |
| f.write("instance_group [\n") |
| for instance_group in config['instance_group']: |
| f.write(" {\n") |
| f.write(" count: " + str(instance_group['count']) + "\n") |
| f.write(" kind: " + instance_group['kind'] + "\n") |
| f.write(" }\n") |
| f.write("]\n") |
| print(f"The configuration file has been saved to '{config_path}'") |
|
|
|
|
| def list_onnx_files(directory): |
| onnx_files = [] |
| for root, _, files in os.walk(directory): |
| for file in files: |
| if file.endswith(".onnx"): |
| onnx_files.append(file) |
| return onnx_files |
|
|
|
|
| def zip_folder(folder_path, output_path): |
| |
| with zipfile.ZipFile(output_path, 'w', zipfile.ZIP_DEFLATED) as zip_file: |
| |
| for root, dirs, files in os.walk(folder_path): |
| for file in files: |
| file_path = os.path.join(root, file) |
| zip_file.write(file_path, arcname=os.path.relpath(file_path, folder_path)) |
| |
| for dir in dirs: |
| dir_path = os.path.join(root, dir) |
| zip_file.write(dir_path, arcname=os.path.relpath(dir_path, folder_path)) |
| print(f"The folder '{folder_path}' has been zipped to '{output_path}'") |
|
|
|
|
| """ |
| python model_packaging.py adaboost_regressor |
| """ |
| if __name__ == "__main__": |
| |
| parser = argparse.ArgumentParser(description='Process a directory with ONNX models.') |
| parser.add_argument('folder_path', type=str, help='Path to the directory with ONNX models.') |
| args = parser.parse_args() |
| folder_path = args.folder_path |
| |
| os.chdir(folder_path) |
| filenames = list_onnx_files(folder_path) |
| version = '1' |
| print(filenames) |
| for filename in filenames: |
| if not filename.startswith("."): |
| foldername = os.path.splitext(filename)[0] |
| if not os.path.exists(foldername): |
| os.makedirs(foldername, exist_ok=True) |
| folderdir = os.path.join(foldername, version) |
| os.makedirs(folderdir, exist_ok=True) |
| shutil.copy(filename, folderdir) |
| model_path = os.path.join(folderdir, filename) |
| config_path = os.path.join(foldername, "config.pbtxt") |
| create_triton_config(model_path, config_path, foldername, max_batch_size=0) |
| os.rename(os.path.join(folderdir, filename), os.path.join(folderdir, 'model.onnx')) |
| print(f"The file '{os.path.join(folderdir, filename)}' has been renamed to '{os.path.join(folderdir, 'model.onnx')}'") |
| print(f"{foldername} folder created successfully!") |
| print(f"{filename} copied to {foldername} successfully!") |
| zip_folder(os.path.join(folder_path, foldername), f"{foldername}.zip") |
| else: |
| print(f"{foldername} folder already exists.") |
|
|