| #!/usr/bin/env python3 | |
| import onnx | |
| from onnxruntime.quantization import QuantType, quantize_dynamic | |
def main():
    """Dynamically quantize ``model.onnx`` and write ``model.int8.onnx``.

    The model is first loaded with ``onnx`` only to inspect its graph:
    every node whose name contains the substring ``"output"`` is collected,
    printed, and then passed to ``quantize_dynamic`` as ``nodes_to_exclude``.

    Quantization is restricted to ``MatMul`` ops, with ``per_channel=True``
    and ``QuantType.QUInt8`` as the weight type.
    """
    source_path = "model.onnx"
    target_path = "model.int8.onnx"

    graph = onnx.load(source_path).graph

    # Gather the nodes that must be left unquantized.
    # NOTE(review): presumably these are the final output layers, kept in
    # float to protect accuracy -- confirm against the model's naming scheme.
    skipped_nodes = []
    for node in graph.node:
        node_name = node.name
        if "output" in node_name:
            skipped_nodes.append(node_name)

    # Echo the exclusion list so the user can verify what was skipped.
    print(skipped_nodes)

    # quantize_dynamic re-reads the model from disk by path; the in-memory
    # copy loaded above is used for name inspection only.
    quantize_dynamic(
        model_input=source_path,
        model_output=target_path,
        op_types_to_quantize=["MatMul"],
        per_channel=True,
        weight_type=QuantType.QUInt8,
        nodes_to_exclude=skipped_nodes,
    )
# Run the quantization only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()