# mms-tts-rknn / convert.py
# Uploaded by danielferr85 via huggingface_hub (commit 676e208).
import os
import sys

from rknn.api import RKNN
# Quantization setting used when no dtype argument is supplied (float build).
DEFAULT_QUANT = False


def parse_arg(argv=None):
    """Parse the command line of this conversion script.

    Expected form::

        python3 convert.py onnx_model_path platform [dtype] [output_rknn_path]

    Args:
        argv: Argument vector to parse, program name at index 0.
            Defaults to ``sys.argv``.

    Returns:
        Tuple ``(model_path, platform, do_quant, output_path)``.
        ``do_quant`` is True for dtype ``'i8'``/``'u8'``, False for ``'fp'``,
        and ``DEFAULT_QUANT`` when no dtype is given. Without an explicit
        output path, it is ``model_path`` with a trailing ``.onnx`` removed
        (case-insensitive) and ``_<platform>.rknn`` appended.

    Raises:
        SystemExit: with code 1, after printing a message, when fewer than
            two positional arguments are given or the dtype is not one of
            ``'i8'``, ``'u8'``, ``'fp'``.
    """
    if argv is None:
        argv = sys.argv
    if len(argv) < 3:
        prog = argv[0] if argv else 'convert.py'
        print("Usage: python3 {} onnx_model_path [platform] [dtype(optional)] [output_rknn_path(optional)]".format(prog))
        print(" platform choose from [rk3562, rk3566, rk3568, rk3576, rk3588, rv1126b]")
        print(" dtype choose from [fp] for [rk3562, rk3566, rk3568, rk3576, rk3588, rv1126b]")
        sys.exit(1)

    model_path = argv[1]
    platform = argv[2]

    do_quant = DEFAULT_QUANT
    if len(argv) > 3:
        model_type = argv[3]
        if model_type not in ('i8', 'u8', 'fp'):
            print("ERROR: Invalid model type: {}".format(model_type))
            sys.exit(1)
        # 'i8'/'u8' request a quantized build; 'fp' keeps floating point.
        do_quant = model_type in ('i8', 'u8')

    if len(argv) > 4:
        output_path = argv[4]
    else:
        # Strip only a trailing '.onnx'. str.replace() would also rewrite
        # '.onnx' elsewhere in the path. It would also return a suffix-less
        # path unchanged, making the export overwrite the input model.
        stem = model_path
        if stem.lower().endswith('.onnx'):
            stem = stem[:-len('.onnx')]
        output_path = f"{stem}_{platform}.rknn"
    return model_path, platform, do_quant, output_path
def main():
    """Convert the ONNX model named on the command line to an RKNN model.

    Configures the RKNN toolkit for the requested platform, loads the ONNX
    file, builds it (quantized when requested), and exports it to
    ``output_path``. The toolkit object is released on every exit path.

    Returns:
        0 on success. Otherwise, the non-zero status returned by the failing
        ``load_onnx`` / ``build`` / ``export_rknn`` call.
    """
    model_path, platform, do_quant, output_path = parse_arg()

    # Create RKNN object
    rknn = RKNN(verbose=False)
    try:
        # Pre-process config
        print('--> Config model')
        config_kwargs = dict(
            target_platform=platform,
            model_pruning=True,
            optimization_level=3,
            single_core_mode=True,
        )
        # Match on the file name only, so a directory whose name contains
        # 'encoder' cannot apply the encoder-only settings to another model.
        if 'encoder' in os.path.basename(model_path):
            # Force these two ops of the encoder graph onto the CPU.
            # NOTE(review): why these ops need the CPU is not recorded here
            # (unsupported or inaccurate on the NPU?) — confirm before editing.
            config_kwargs['op_target'] = {'7398-rs': 'cpu', '5773-rs': 'cpu'}
        rknn.config(**config_kwargs)
        print('done')

        # Load model
        print('--> Loading model')
        ret = rknn.load_onnx(model=model_path)
        if ret != 0:
            print('Load model failed!')
            return ret
        print('done')

        # Build model
        print('--> Building model')
        # NOTE(review): no calibration dataset is passed, so a quantized
        # ('i8'/'u8') build may be rejected by the toolkit — confirm.
        ret = rknn.build(do_quantization=do_quant)
        if ret != 0:
            print('Build model failed!')
            return ret
        print('done')

        # Export rknn model
        print('--> Export rknn model')
        ret = rknn.export_rknn(output_path)
        if ret != 0:
            print('Export rknn model failed!')
            return ret
        print('done')
        return 0
    finally:
        # Release toolkit resources on success and on every failure path.
        rknn.release()


if __name__ == '__main__':
    sys.exit(main())