| | import onnx |
| | from onnx import helper, TensorProto |
| | import numpy as np |
| |
|
| | |
| | |
def optimize_sa_model(model_path: str, out_path: str, is_fp16: bool = False) -> None:
    """Rewrite a SAM decoder ONNX model so the original-image-size input is
    derived from a tensor's *shape* instead of being passed as fixed values.

    The static input ``orig_im_size`` is removed and replaced by a new input
    ``orig_im_size_shape`` with symbolic dims ("height", "width").  A ``Shape``
    node is inserted whose output reuses the old input name, so every
    downstream consumer keeps working unchanged.  When ``is_fp16`` is True,
    a hard-coded set of size-computation constants and nodes (names from the
    original export — TODO confirm they match the model being patched) is
    forced back to fp32 so the size arithmetic stays exact.

    Args:
        model_path: Path of the ONNX model to load.
        out_path: Path where the modified model is saved.
        is_fp16: Whether the model was exported in fp16 and needs the
            size-computation subgraph patched back to fp32.

    Raises:
        AssertionError: If ``orig_im_size`` is not an input of the graph.
        ValueError: If the new input failed to register, or an expected
            Constant node has no ``value`` attribute.
    """
    model = onnx.load(model_path)
    graph = model.graph

    old_input_name = "orig_im_size"
    new_input_name = "orig_im_size_shape"

    # Locate and drop the old static input.
    old_inputs = {vi.name: vi for vi in graph.input}
    assert old_input_name in old_inputs, f"Input {old_input_name} not found"
    graph.input.remove(old_inputs[old_input_name])

    # The new input's *shape* (height, width) carries the information the old
    # input carried as *values*; element type FLOAT matches the old input.
    new_input_vi = helper.make_tensor_value_info(
        new_input_name, TensorProto.FLOAT, ["height", "width"]
    )
    graph.input.extend([new_input_vi])

    # Sanity check that the extend above registered the input.
    if new_input_name not in {vi.name for vi in graph.input}:
        raise ValueError(f"Input '{new_input_name}' does not exist in the graph inputs.")

    # The Shape node's output reuses the old input name, so no consumer in the
    # graph needs rewiring.  Inserted at index 0 to keep nodes topologically
    # sorted (it has no predecessors).
    shape_node = helper.make_node(
        "Shape",
        inputs=[new_input_name],
        outputs=[old_input_name],
        name="shape_of_orig_im_size_shape",
    )
    graph.node.insert(0, shape_node)

    if is_fp16:
        _patch_fp16_size_subgraph(graph)

    onnx.checker.check_model(model)
    onnx.save(model, out_path)
    print(f"Saved to {out_path}")


def _patch_fp16_size_subgraph(graph) -> None:
    """Force the image-size computation of an fp16-exported graph back to fp32.

    Node/constant names are hard-coded from the original export and are
    presumably specific to the SAM ViT-B decoder — verify against the model.
    """
    # Constants feeding the size math: upcast their payloads to fp32.
    fp16_constants = ["/Constant_85", "/Constant_86"]
    for node in graph.node:
        if node.op_type == "Constant" and node.name in fp16_constants:
            print(node.name)
            for attr in node.attribute:
                if attr.name == "value":
                    arr = onnx.numpy_helper.to_array(attr.t).astype(np.float32)
                    # Pass attr.t.name explicitly: from_array() without a name
                    # would drop the tensor's name (bug in the original).
                    attr.t.CopyFrom(onnx.numpy_helper.from_array(arr, attr.t.name))
                    break
            else:
                # for/else: only raise when NO attribute was named "value".
                raise ValueError(f"Constant node '{node.name}' does not have a 'value' attribute.")

    # Retype the inputs/outputs of the size-math nodes to fp32.
    fp16_nodes = ["/ReduceMax", "/Reciprocal", "/Mul_19", "/Mul_20", "/Add_11", "/Floor"]
    # Hoist a name -> value_info map: the original rescanned graph.value_info
    # for every input/output (accidental O(n^2)).
    value_infos = {vi.name: vi for vi in graph.value_info}
    for node in graph.node:
        if node.name not in fp16_nodes:
            continue
        print(f"Processing node: {node.name}")
        for input_name in node.input:
            vi = value_infos.get(input_name)
            if vi is not None:
                vi.type.tensor_type.elem_type = TensorProto.FLOAT
                print(f" - Change input: {input_name} to fp32")
        for output_name in node.output:
            vi = value_infos.get(output_name)
            if vi is not None:
                vi.type.tensor_type.elem_type = TensorProto.FLOAT
                print(f" - Change output: {output_name} to fp32")

    # Retarget the Cast: find the "to" attribute by name instead of assuming
    # it is attribute[0] (more robust if the exporter reorders attributes).
    for node in graph.node:
        if node.name == "/Cast_9":
            for attr in node.attribute:
                if attr.name == "to":
                    attr.i = TensorProto.FLOAT
                    print(f"Changed /Cast_9 to fp32")
                    break
            break
| |
|
| | |
| | |
| | |
if __name__ == "__main__":
    # Patch the fp16 SAM ViT-B decoder so orig_im_size is inferred from the
    # shape of a dynamic input; guard prevents running on mere import.
    optimize_sa_model(
        "sam_vit_b_01ec64.decoder-fp16.onnx",
        "sam_vit_b_01ec64.decoder-orig-img-size-dynamic-fp16.onnx",
        True,
    )