#!/usr/bin/env python
# coding: utf-8
"""Convert the tone-clone ONNX models to RKNN format for the rk3588 NPU.

Usage: python <script> [all|tone_clone|tone_color_extract]
"""
import datetime
import argparse
from rknn.api import RKNN
from sys import exit

# Model configuration: model type -> source ONNX file.
# The RKNN output path is derived by replacing ".onnx" with ".rknn".
MODELS = {
    'tone_clone': 'tone_clone_model.onnx',
    'tone_color_extract': 'tone_color_extract_model.onnx',
}

# Audio lengths used to build the dynamic input shapes of each model.
TARGET_AUDIO_LENS = [1024]
SOURCE_AUDIO_LENS = [1024]
# Size of the 513-wide audio axis shared by both models' audio inputs.
AUDIO_DIM = 513
# Passed to RKNN.build(do_quantization=...). NOTE: per the RKNN-Toolkit2 docs,
# quantization needs a calibration dataset, which is not supplied below; add
# one to the build() call before enabling this.
QUANTIZE = False
# Not referenced in this script.
detailed_performance_log = True


def _dynamic_input_shapes(model_type):
    """Return the ``dynamic_input`` shape sets for *model_type*.

    The result is one list of per-input shapes for each configured audio length.

    Raises:
        ValueError: if *model_type* has no shape definition here.
    """
    if model_type == 'tone_clone':
        return [
            [
                [1, AUDIO_DIM, target_audio_len],  # audio
                [1],                               # audio_length
                [1, 256, 1],                       # src_tone
                [1, 256, 1],                       # dest_tone
                [1],                               # tau
            ]
            for target_audio_len in TARGET_AUDIO_LENS
        ]
    if model_type == 'tone_color_extract':
        return [
            [
                [1, source_audio_len, AUDIO_DIM],  # audio
            ]
            for source_audio_len in SOURCE_AUDIO_LENS
        ]
    # Guards against adding a MODELS entry without adding its shapes here.
    raise ValueError(f"no input shapes defined for model type {model_type!r}")


def convert_model(model_type):
    """Convert the ONNX model registered under *model_type* to an RKNN file.

    Args:
        model_type: a key of MODELS ('tone_clone' or 'tone_color_extract').

    Returns:
        True if loading, building and exporting all succeeded; False otherwise.
        An error message is printed for every failure.
    """
    if model_type not in MODELS:
        print(f"错误: 不支持的模型类型 {model_type}")
        return False

    onnx_model = MODELS[model_type]
    rknn_model = onnx_model.replace(".onnx", ".rknn")
    shapes = _dynamic_input_shapes(model_type)
    # To build a static-shape model instead, pass dynamic_input=None below.
    timedate_iso = datetime.datetime.now().isoformat()

    rknn = RKNN(verbose=True)
    # Release the toolkit object on every exit path (success, failure return,
    # or exception), so converting several models in one run leaves none alive.
    try:
        rknn.config(
            quantized_dtype='w8a8',
            quantized_algorithm='normal',
            quantized_method='channel',
            quantized_hybrid_level=0,
            target_platform='rk3588',
            quant_img_RGB2BGR=False,
            float_dtype='float16',
            optimization_level=3,
            custom_string=f"converted by: qq: 232004040, email: 2302004040@qq.com at {timedate_iso}",
            remove_weight=False,
            compress_weight=False,
            inputs_yuv_fmt=None,
            single_core_mode=False,
            dynamic_input=shapes,
            model_pruning=False,
            op_target=None,
            quantize_weight=False,
            remove_reshape=False,
            sparse_infer=False,
            enable_flash_attention=False,
            # disable_rules=['convert_gemm_by_exmatmul']
        )

        print(f"开始转换 {model_type} 模型...")

        ret = rknn.load_onnx(model=onnx_model)
        if ret != 0:
            print("加载ONNX模型失败")
            return False

        ret = rknn.build(do_quantization=QUANTIZE, rknn_batch_size=None)
        if ret != 0:
            print("构建RKNN模型失败")
            return False

        ret = rknn.export_rknn(rknn_model)
        if ret != 0:
            print("导出RKNN模型失败")
            return False

        print(f"成功转换模型: {rknn_model}")
        return True
    finally:
        rknn.release()


def main():
    """Parse the command line and convert the requested model(s).

    With 'all' (the default), every model in MODELS is attempted; a failure is
    reported but does not stop the remaining conversions.

    Returns:
        Process exit status: 0 if every requested conversion succeeded, else 1.
    """
    parser = argparse.ArgumentParser(description='转换ONNX模型到RKNN格式')
    parser.add_argument('model_type', nargs='?', default='all',
                        choices=['all', *MODELS],
                        help='要转换的模型类型 (默认: all)')
    args = parser.parse_args()

    targets = list(MODELS) if args.model_type == 'all' else [args.model_type]
    all_ok = True
    for model_type in targets:
        if not convert_model(model_type):
            print(f"转换 {model_type} 失败")
            all_ok = False
    return 0 if all_ok else 1


if __name__ == '__main__':
    exit(main())