|
|
import onnx |
|
|
from onnx import version_converter, helper |
|
|
|
|
|
def convert_opset(input_model_path: str, output_model_path: str, target_opset: int): |
|
|
""" |
|
|
使用 onnx.version_converter 安全地转换 ONNX 模型的 Opset 版本。 |
|
|
|
|
|
Args: |
|
|
input_model_path (str): 输入 ONNX 模型文件路径。 |
|
|
output_model_path (str): 输出转换后 ONNX 模型文件路径。 |
|
|
target_opset (int): 目标 Opset 版本号。 |
|
|
""" |
|
|
try: |
|
|
|
|
|
original_model = onnx.load(input_model_path) |
|
|
print(f"原始 Opset Import: {original_model.opset_import}") |
|
|
|
|
|
|
|
|
|
|
|
converted_model = version_converter.convert_version(original_model, target_opset) |
|
|
|
|
|
|
|
|
onnx.checker.check_model(converted_model) |
|
|
print("ONNX Checker 检查转换后的模型通过。") |
|
|
|
|
|
|
|
|
onnx.save(converted_model, output_model_path) |
|
|
print(f"模型已安全转换为 Opset {target_opset} 并保存到: {output_model_path}") |
|
|
print(f"转换后的 Opset Import: {converted_model.opset_import}") |
|
|
|
|
|
except ValueError as e: |
|
|
print(f"转换失败:模型包含无法转换到 Opset {target_opset} 的算子。错误: {e}") |
|
|
except Exception as e: |
|
|
print(f"处理模型时发生错误: {e}") |
|
|
|
|
|
|
|
|
input_model = "F5_Transformer.onnx" |
|
|
output_model_converted = "F5_Transformer_opset19.onnx" |
|
|
target_version = 19 |
|
|
|
|
|
convert_opset(input_model, output_model_converted, target_version) |
|
|
|