RTMPose / replace_hardsigmoid.py
fangmingguo's picture
Upload 13 files
993d81c verified
#!/usr/bin/env python3
"""Replace HardSigmoid with Mul + Add + Clip in RTMPose ONNX.
Replacing HardSigmoid with standard ops (Mul/Add/Clip)
allows FP32 or U16 quantization on these nodes.
Equivalent: HardSigmoid(x) = Clip(x * alpha + beta, 0, 1)
"""
import argparse
import numpy as np
import onnx
from onnx import TensorProto, helper, numpy_helper
def replace_hardsigmoid(model: onnx.ModelProto) -> int:
    """Rewrite every top-level HardSigmoid node as Mul -> Add -> Clip, in place.

    Uses the identity ``HardSigmoid(x) = Clip(x * alpha + beta, 0, 1)``.
    Each replacement reuses the original node's input and output tensor
    names, so neighbouring nodes need no rewiring.  Only ``model.graph.node``
    is scanned; subgraphs held in node attributes are not visited.

    Args:
        model: Model whose graph is modified in place.

    Returns:
        The number of HardSigmoid nodes that were replaced.
    """
    graph = model.graph
    rewritten_nodes = []
    new_initializers = []
    hs_count = 0

    # Snapshot the node list: graph.node is cleared and refilled below.
    for node in list(graph.node):
        if node.op_type != "HardSigmoid":
            rewritten_nodes.append(node)
            continue

        hs_count += 1
        prefix = f"hs_replace_{hs_count}"
        src = node.input[0]
        dst = node.output[0]

        # Defaults used when the node carries no explicit attribute.
        alpha = 0.2
        beta = 0.5
        for attr in node.attribute:
            if attr.name == "alpha":
                alpha = attr.f
            elif attr.name == "beta":
                beta = attr.f

        alpha_name = f"{prefix}_alpha"
        beta_name = f"{prefix}_beta"
        min_name = f"{prefix}_min"
        max_name = f"{prefix}_max"

        # Emit rank-0 (scalar) constants.  The ONNX Clip operator requires
        # its min/max inputs to be scalars, and a scalar alpha/beta
        # broadcasts against the input without ever changing its shape.
        # NOTE(review): constants are always float32; a non-float32
        # HardSigmoid input would need matching constant dtypes - confirm
        # the model is FP32 before running this.
        constants = (
            (alpha_name, alpha),
            (beta_name, beta),
            (min_name, 0.0),
            (max_name, 1.0),
        )
        new_initializers.extend(
            numpy_helper.from_array(np.array(value, dtype=np.float32), name)
            for name, value in constants
        )

        mul_out = f"{prefix}_mul_out"
        add_out = f"{prefix}_add_out"
        # NOTE(review): Clip takes min/max as tensor inputs (rather than
        # attributes) from opset 11 onwards - confirm the model's opset.
        rewritten_nodes.extend(
            [
                helper.make_node(
                    "Mul", [src, alpha_name], [mul_out], name=f"{prefix}_Mul"
                ),
                helper.make_node(
                    "Add", [mul_out, beta_name], [add_out], name=f"{prefix}_Add"
                ),
                helper.make_node(
                    "Clip",
                    [add_out, min_name, max_name],
                    [dst],
                    name=f"{prefix}_Clip",
                ),
            ]
        )

    # Swap the node list wholesale so each replacement sits exactly where
    # its HardSigmoid was, keeping the graph's node order intact.
    del graph.node[:]
    graph.node.extend(rewritten_nodes)
    graph.initializer.extend(new_initializers)
    return hs_count
def fix_batch_dim(model: "onnx.ModelProto") -> None:
    """Pin the first dimension of every runtime graph input to 1, in place.

    An input is rewritten when its first dimension is symbolic
    (``dim_param`` set) or has a fixed value other than 1.  Inputs with no
    dimensions are left alone.  Entries of ``graph.input`` that are backed
    by an initializer are skipped: they are weights, not runtime inputs, and
    their first dimension is not a batch size.

    Args:
        model: Model whose ``graph.input`` shapes are modified in place.
    """
    graph = model.graph
    # A graph may list its initializers among graph.input as well; rewriting
    # those would make the declared weight shape disagree with the data.
    initializer_names = {init.name for init in graph.initializer}

    for graph_input in graph.input:
        if graph_input.name in initializer_names:
            continue
        dims = graph_input.type.tensor_type.shape.dim
        if not dims:
            continue
        first = dims[0]
        if first.dim_param or first.dim_value != 1:
            # Clear any symbolic name before storing the fixed size.
            first.dim_param = ""
            first.dim_value = 1
def main():
    """Command-line entry point: convert a model, save it, and smoke-test it.

    Steps: load ``--input``, replace HardSigmoid nodes, pin the input batch
    dimension to 1, save to ``--output``, then run the saved file once on
    random data with onnxruntime's CPU provider.  The final check only shows
    that the saved model loads and executes; it does not compare its outputs
    against the original model.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", default="rtmpose_m_256x192.onnx")
    parser.add_argument("--output", default="rtmpose_m_256x192_no_hs.onnx")
    args = parser.parse_args()

    model = onnx.load(args.input)

    count = replace_hardsigmoid(model)
    print(f"Replaced {count} HardSigmoid -> Mul+Add+Clip")

    fix_batch_dim(model)
    print("Fixed dynamic batch dim -> 1")

    onnx.save(model, args.output)
    print(f"Saved: {args.output}")

    # onnxruntime is only needed for the smoke test below.
    import onnxruntime as ort

    sess = ort.InferenceSession(args.output, providers=["CPUExecutionProvider"])
    inp = sess.get_inputs()[0]
    # onnxruntime reports symbolic dimensions as strings and unknown ones as
    # None; use size 1 for those so a dummy tensor can still be built.
    dummy_shape = [dim if isinstance(dim, int) else 1 for dim in inp.shape]
    # NOTE(review): assumes the first model input is float32 - confirm.
    dummy = np.random.randn(*dummy_shape).astype(np.float32)
    outs = sess.run(None, {inp.name: dummy})
    print(f"Verify OK: {[o.shape for o in outs]}")


if __name__ == "__main__":
    main()