File size: 4,621 Bytes
96a90aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import onnx
from onnx import helper, TensorProto
import numpy as np

# Optimize the SAM decoder model to accept dynamic original image size input
# to enable constant folding with freeDimensionOverride option in onnxruntime.
def optimize_sa_model(model_path: str, out_path: str, is_fp16: bool = False):
    """Rewrite the SAM decoder graph so ``orig_im_size`` is derived dynamically.

    Replaces the explicit ``orig_im_size`` graph input with a new input
    ``orig_im_size_shape`` of symbolic shape ["height", "width"], and inserts a
    Shape node that reproduces the old input's value ([H, W] as INT64) from the
    new input's runtime shape. This lets onnxruntime constant-fold the
    size-dependent subgraph via the freeDimensionOverride option.

    Args:
        model_path: Path to the source ONNX decoder model.
        out_path: Path where the rewritten model is saved.
        is_fp16: When True, additionally promote a hard-coded list of fp16
            constants / node value-infos to fp32, because the CPU kernels used
            for constant folding do not support fp16.

    Raises:
        AssertionError: If ``orig_im_size`` is not an input of the graph.
        ValueError: If an expected Constant node lacks a ``value`` attribute.
    """
    model = onnx.load(model_path)
    graph = model.graph

    old_input_name = "orig_im_size"
    new_input_name = "orig_im_size_shape"

    # 1) Find old input
    old_inputs = {vi.name: vi for vi in graph.input}
    assert old_input_name in old_inputs, f"Input {old_input_name} not found"
    old_vi = old_inputs[old_input_name]

    # 2) Remove old input and add new input with symbolic shape [height, width].
    #    The element type is irrelevant: only the tensor's *shape* is consumed.
    graph.input.remove(old_vi)
    new_input_vi = helper.make_tensor_value_info(
        new_input_name, TensorProto.FLOAT, ["height", "width"]
    )
    graph.input.extend([new_input_vi])

    # 3) Insert Shape node: Shape(new_input) -> INT64 1-D tensor [H, W].
    #    Its output reuses the old input name so downstream consumers are
    #    rewired without touching their input lists.
    shape_node = helper.make_node(
        "Shape",
        inputs=[new_input_name],
        outputs=[old_input_name],
        name="shape_of_orig_im_size_shape",
    )
    graph.node.insert(0, shape_node)

    # 4) The origin input dtype is not INT64, which would normally need a Cast
    #    node — but the original model already has a Cast right after the
    #    input, so nothing to add here.

    if is_fp16:
        # For the fp16 model, the CPU kernels don't support constant folding
        # for fp16, so promote specific fp16 constants and the value-infos of
        # the size-computation nodes to fp32.
        fp16_constants = ["/Constant_85", "/Constant_86"]

        for node in graph.node:
            if node.op_type == "Constant" and node.name in fp16_constants:
                print(node.name)
                # Locate the "value" attribute of the Constant node.
                # NOTE: `for`/`else` — raise only if NO "value" attribute
                # exists (the previous code raised on the first attribute
                # that wasn't named "value").
                for attr in node.attribute:
                    if attr.name == "value":
                        # Re-encode the tensor payload as fp32 in place.
                        tensor = onnx.numpy_helper.to_array(attr.t)
                        new_tensor = tensor.astype(np.float32)
                        attr.t.CopyFrom(onnx.numpy_helper.from_array(new_tensor))
                        break
                else:
                    raise ValueError(f"Constant node '{node.name}' does not have a 'value' attribute.")

        fp16_nodes = ["/ReduceMax", "/Reciprocal", "/Mul_19", "/Mul_20", "/Add_11", "/Floor"]
        # Promote the recorded value-info dtypes of these nodes' inputs and
        # outputs to fp32 so the checker/folder sees a consistent graph.
        for node in graph.node:
            if node.name in fp16_nodes:
                print(f"Processing node: {node.name}")
                for input_name in node.input:
                    for value_info in graph.value_info:
                        if value_info.name == input_name:
                            value_info.type.tensor_type.elem_type = TensorProto.FLOAT
                            print(f" - Change input: {input_name} to fp32")

                for output_name in node.output:
                    for value_info in graph.value_info:
                        if value_info.name == output_name:
                            value_info.type.tensor_type.elem_type = TensorProto.FLOAT
                            print(f" - Change output: {output_name} to fp32")

        # Retarget /Cast_9 to emit fp32 (attribute 0 is its "to" dtype —
        # TODO confirm attribute ordering holds for this model).
        for node in graph.node:
            if node.name == "/Cast_9":
                node.attribute[0].i = TensorProto.FLOAT
                print("Changed /Cast_9 to fp32")
                break

    onnx.checker.check_model(model)
    onnx.save(model, out_path)
    print(f"Saved to {out_path}")

# Run the rewrite at import time (this file is used as a one-shot script).
# the original int8 decoder model: https://huggingface.co/schmuell/sam-b-fp16/blob/main/sam_vit_b-decoder-int8.onnx
# optimize_sa_model("sam_vit_b-decoder-int8.onnx", "sam_vit_b-decoder-int8-orig-img-size-dynamic.onnx", False)
# the original fp32 decoder model: https://huggingface.co/schmuell/sam-b-fp16/blob/main/sam_vit_b_01ec64.decoder.onnx
optimize_sa_model("sam_vit_b_01ec64.decoder-fp16.onnx", "sam_vit_b_01ec64.decoder-orig-img-size-dynamic-fp16.onnx", True)