Upload quantize_onnx.py with huggingface_hub
Browse files- quantize_onnx.py +53 -0
quantize_onnx.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Generate fp16 and int8 variants of formantnet.onnx.
|
| 4 |
+
Run in onnxconv env.
|
| 5 |
+
"""
|
| 6 |
+
import os, sys
|
| 7 |
+
import numpy as np
|
| 8 |
+
import onnx
|
| 9 |
+
import onnxruntime as ort
|
| 10 |
+
from onnxconverter_common import float16
|
| 11 |
+
from onnxruntime.quantization import quantize_dynamic, QuantType
|
| 12 |
+
|
| 13 |
+
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 14 |
+
fp32_path = os.path.join(SCRIPT_DIR, 'formantnet.onnx')
|
| 15 |
+
fp16_path = os.path.join(SCRIPT_DIR, 'formantnet_fp16.onnx')
|
| 16 |
+
int8_path = os.path.join(SCRIPT_DIR, 'formantnet_int8.onnx')
|
| 17 |
+
|
| 18 |
+
# Reference output for parity checks
|
| 19 |
+
np.random.seed(42)
|
| 20 |
+
x = np.random.randn(1, 200, 257).astype(np.float32)
|
| 21 |
+
ref_sess = ort.InferenceSession(fp32_path, providers=['CPUExecutionProvider'])
|
| 22 |
+
input_name = ref_sess.get_inputs()[0].name
|
| 23 |
+
ref_out = ref_sess.run(None, {input_name: x})[0]
|
| 24 |
+
|
| 25 |
+
def check(path, label, abs_thr, rel_thr):
|
| 26 |
+
sess = ort.InferenceSession(path, providers=['CPUExecutionProvider'])
|
| 27 |
+
iname = sess.get_inputs()[0].name
|
| 28 |
+
out = sess.run(None, {iname: x.astype(np.float32)})[0]
|
| 29 |
+
max_abs = float(np.abs(ref_out - out).max())
|
| 30 |
+
max_rel = float((np.abs(ref_out - out) / (np.abs(ref_out) + 1e-8)).max())
|
| 31 |
+
size_mb = os.path.getsize(path) / 1e6
|
| 32 |
+
ok = max_abs < abs_thr and max_rel < rel_thr
|
| 33 |
+
status = "✓" if ok else "✗"
|
| 34 |
+
print(f" {label:6s} {size_mb:.2f} MB max_abs={max_abs:.2e} max_rel={max_rel:.2e} {status}")
|
| 35 |
+
return ok
|
| 36 |
+
|
| 37 |
+
print(f"Source: {fp32_path} ({os.path.getsize(fp32_path)/1e6:.2f} MB)\n")
|
| 38 |
+
|
| 39 |
+
# --- fp16 ---
|
| 40 |
+
print("Generating fp16...")
|
| 41 |
+
model_fp32 = onnx.load(fp32_path)
|
| 42 |
+
model_fp16 = float16.convert_float_to_float16(model_fp32, keep_io_types=True)
|
| 43 |
+
onnx.save(model_fp16, fp16_path)
|
| 44 |
+
print(f" Saved → {fp16_path}")
|
| 45 |
+
check(fp16_path, 'fp16', abs_thr=0.005, rel_thr=0.05)
|
| 46 |
+
|
| 47 |
+
# --- int8 dynamic quantization ---
|
| 48 |
+
print("\nGenerating int8 (dynamic quantization)...")
|
| 49 |
+
quantize_dynamic(fp32_path, int8_path, weight_type=QuantType.QInt8)
|
| 50 |
+
print(f" Saved → {int8_path}")
|
| 51 |
+
check(int8_path, 'int8', abs_thr=0.15, rel_thr=999.0) # gate on abs only; rel can spike near zero (same as DeepFormants-onnx int8)
|
| 52 |
+
|
| 53 |
+
print("\nDone.")
|