File size: 3,466 Bytes
993d81c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#!/usr/bin/env python3
"""Replace HardSigmoid with Mul + Add + Clip in RTMPose ONNX.

Replacing HardSigmoid with standard ops (Mul/Add/Clip)
allows FP32 or U16 quantization on these nodes.

Equivalent: HardSigmoid(x) = Clip(x * alpha + beta, 0, 1)
"""

import argparse

import numpy as np
import onnx
from onnx import TensorProto, helper, numpy_helper


def replace_hardsigmoid(model: onnx.ModelProto) -> int:
    """Rewrite every HardSigmoid node of the main graph as Mul -> Add -> Clip, in place.

    Each ``y = HardSigmoid(x)`` in the default ONNX domain becomes::

        y = Clip(Add(Mul(x, alpha), beta), 0, 1)

    Each node uses its own ``alpha``/``beta`` attributes. When an attribute is
    absent, the ONNX defaults 0.2 and 0.5 are used. The Clip node writes to the
    original output name, so downstream consumers need no rewiring. The constants
    are added as float32 rank-0 (scalar) initializers named ``hs_replace_<k>_*``.

    Only the top-level ``model.graph.node`` list is scanned. Nodes inside
    control-flow subgraphs (If/Loop/Scan bodies) are not visited.

    Args:
        model: Model to modify in place.

    Returns:
        Number of HardSigmoid nodes replaced.
    """
    graph = model.graph

    # From opset 11 on, Clip takes min/max as *inputs*; before that they are
    # attributes.
    # NOTE(review): opsets below 7 are not handled. Mul/Add there need an
    # explicit ``broadcast`` attribute for the scalar operand.
    opset = next(
        (imp.version for imp in model.opset_import if imp.domain in ("", "ai.onnx")),
        None,
    )
    clip_bounds_as_inputs = opset is None or opset >= 11

    def scalar(value: float, name: str) -> TensorProto:
        # Use a rank-0 tensor because the Clip spec requires min/max to be
        # scalars (empty shape). For Mul/Add a scalar also keeps the output rank
        # equal to the input rank, even when x itself is rank 0.
        return numpy_helper.from_array(np.array(value, dtype=np.float32), name)

    nodes = list(graph.node)  # snapshot taken before the repeated field is cleared below
    new_nodes = []
    initializers_to_add = []
    hs_count = 0

    for n in nodes:
        # Skip anything that is not the standard-domain HardSigmoid. A custom op
        # that happens to share the name must not be rewritten.
        if n.op_type != "HardSigmoid" or n.domain not in ("", "ai.onnx"):
            new_nodes.append(n)
            continue

        hs_count += 1
        prefix = f"hs_replace_{hs_count}"

        alpha, beta = 0.2, 0.5
        for attr in n.attribute:
            if attr.name == "alpha":
                alpha = attr.f
            elif attr.name == "beta":
                beta = attr.f

        alpha_name = f"{prefix}_alpha"
        beta_name = f"{prefix}_beta"
        mul_out = f"{prefix}_mul_out"
        add_out = f"{prefix}_add_out"
        initializers_to_add += [scalar(alpha, alpha_name), scalar(beta, beta_name)]

        new_nodes.append(
            helper.make_node("Mul", [n.input[0], alpha_name], [mul_out], name=f"{prefix}_Mul")
        )
        new_nodes.append(
            helper.make_node("Add", [mul_out, beta_name], [add_out], name=f"{prefix}_Add")
        )

        if clip_bounds_as_inputs:
            min_name = f"{prefix}_min"
            max_name = f"{prefix}_max"
            initializers_to_add += [scalar(0.0, min_name), scalar(1.0, max_name)]
            clip_node = helper.make_node(
                "Clip", [add_out, min_name, max_name], [n.output[0]], name=f"{prefix}_Clip"
            )
        else:
            clip_node = helper.make_node(
                "Clip", [add_out], [n.output[0]], name=f"{prefix}_Clip", min=0.0, max=1.0
            )
        new_nodes.append(clip_node)

    del graph.node[:]
    graph.node.extend(new_nodes)
    graph.initializer.extend(initializers_to_add)

    return hs_count


def fix_batch_dim(model: "onnx.ModelProto", batch_size: int = 1):
    """Pin the leading (batch) dimension of every real graph input to ``batch_size``, in place.

    A leading dimension is rewritten in two cases:

    - it is symbolic (``dim_param`` set);
    - its ``dim_value`` differs from ``batch_size``. This includes dimensions with
      neither field set, because ``dim_value`` then reads as 0.

    Rank-0 inputs and inputs without shape information are left untouched.

    Inputs that are also initializers are skipped. Older models (IR version < 4)
    list weights in ``graph.input``, and rewriting their first dimension would make
    the declared shape disagree with the stored tensor.

    Only ``model.graph.input`` is modified. Graph outputs keep whatever batch
    dimension they declared.

    Args:
        model: Model to modify in place.
        batch_size: Value to pin the leading dimension to (default 1).

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    graph = model.graph
    initializer_names = {init.name for init in graph.initializer}

    for inp in graph.input:
        if inp.name in initializer_names:
            continue
        dims = inp.type.tensor_type.shape.dim
        if not dims:
            continue
        d0 = dims[0]
        if d0.dim_param or d0.dim_value != batch_size:
            # dim_value and dim_param share a oneof. Drop the symbolic name
            # explicitly, then store the concrete value.
            d0.dim_param = ""
            d0.dim_value = batch_size


def _random_feed(sess, seed: int = 0) -> dict:
    """Build a reproducible random input feed for every input of an onnxruntime session.

    onnxruntime reports symbolic dimensions as strings and unknown ones as None.
    Those, and any non-positive size, are replaced by 1 so a concrete array can be
    allocated. float16/float64 inputs get matching dtypes; everything else is fed
    as float32.
    """
    # NOTE(review): integer/bool inputs would be fed as float32. That is fine for
    # an image-input model, but confirm before reusing this on other graphs.
    dtypes = {"tensor(float16)": np.float16, "tensor(double)": np.float64}
    rng = np.random.default_rng(seed)
    feed = {}
    for inp in sess.get_inputs():
        shape = [d if isinstance(d, int) and d > 0 else 1 for d in inp.shape]
        feed[inp.name] = rng.standard_normal(shape).astype(dtypes.get(inp.type, np.float32))
    return feed


def main():
    """Command-line entry point.

    Loads ``--input`` and replaces its HardSigmoid nodes. It then pins the input
    batch dimension to 1 and saves the result to ``--output``. Finally it verifies
    the rewrite: the original and the rewritten model run with onnxruntime (CPU) on
    the same random input, and their outputs must agree within ``--atol``.
    Otherwise the process exits with a non-zero status.
    """
    ap = argparse.ArgumentParser(
        description="Replace HardSigmoid with Mul + Add + Clip in an ONNX model."
    )
    ap.add_argument("--input", default="rtmpose_m_256x192.onnx")
    ap.add_argument("--output", default="rtmpose_m_256x192_no_hs.onnx")
    ap.add_argument(
        "--atol",
        type=float,
        default=1e-4,
        help="absolute tolerance when comparing rewritten vs. original outputs",
    )
    args = ap.parse_args()

    model = onnx.load(args.input)
    count = replace_hardsigmoid(model)
    print(f"Replaced {count} HardSigmoid -> Mul+Add+Clip")

    fix_batch_dim(model)
    print("Fixed dynamic batch dim -> 1")

    onnx.save(model, args.output)
    print(f"Saved: {args.output}")

    # Imported here because onnxruntime is only needed for the verification step.
    import onnxruntime as ort

    providers = ["CPUExecutionProvider"]
    new_sess = ort.InferenceSession(args.output, providers=providers)
    ref_sess = ort.InferenceSession(args.input, providers=providers)

    # Build the feed from the rewritten model (batch pinned to 1). The original
    # model declares the same input names, so it accepts this feed as well.
    feed = _random_feed(new_sess)
    outs = new_sess.run(None, feed)
    refs = ref_sess.run(None, feed)

    if len(outs) != len(refs) or any(o.shape != r.shape for o, r in zip(outs, refs)):
        raise SystemExit(
            f"Verify FAILED: output shapes differ "
            f"{[o.shape for o in outs]} vs {[r.shape for r in refs]}"
        )
    # Cast to float64 so integer/bool outputs can be subtracted safely.
    diffs = [
        float(np.max(np.abs(np.asarray(o, np.float64) - np.asarray(r, np.float64))))
        if o.size
        else 0.0
        for o, r in zip(outs, refs)
    ]
    max_diff = max(diffs, default=0.0)
    if max_diff > args.atol:
        raise SystemExit(f"Verify FAILED: max abs diff {max_diff:.3g} > atol {args.atol:g}")
    print(f"Verify OK: {[o.shape for o in outs]} (max abs diff vs original {max_diff:.3g})")


if __name__ == "__main__":
    # Script entry point. main() returns None, so SystemExit(None) exits with status 0.
    raise SystemExit(main())