FormantNet / quantize_onnx.py
FredrikKarlssonSpeech's picture
Upload quantize_onnx.py with huggingface_hub
ee019e8 verified
#!/usr/bin/env python3
"""
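Generate fp16 and int8 variants of formantnet.onnx and check each variant's
output parity against the fp32 model on a fixed random input.
Run in the onnxconv environment (needs onnx, onnxruntime and onnxconverter-common).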
"""
import os, sys
import numpy as np
import onnx
import onnxruntime as ort
from onnxconverter_common import float16
from onnxruntime.quantization import quantize_dynamic, QuantType
# All model files (source and generated variants) live next to this script.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
fp32_path, fp16_path, int8_path = (
    os.path.join(SCRIPT_DIR, fname)
    for fname in ('formantnet.onnx', 'formantnet_fp16.onnx', 'formantnet_int8.onnx')
)

# Reference output for parity checks: a fixed-seed random float32 input is run
# once through the original fp32 model on the CPU provider; every variant is
# later compared against this result.
np.random.seed(42)
input_shape = (1, 200, 257)
x = np.random.randn(*input_shape).astype(np.float32)
ref_sess = ort.InferenceSession(fp32_path, providers=['CPUExecutionProvider'])
first_input = ref_sess.get_inputs()[0]
input_name = first_input.name
ref_out = ref_sess.run(output_names=None, input_feed={input_name: x})[0]
def check(path, label, abs_thr, rel_thr, *, inputs=None, reference=None):
    """Run the model at *path* and compare its first output to a reference.

    Prints one summary line (file size in MB, max absolute / relative error and
    a pass/fail mark) and returns whether the variant passes.

    Args:
        path: Path to the ONNX model to evaluate (run on the CPU provider).
        label: Short name shown in the summary line (e.g. 'fp16').
        abs_thr: Strict upper bound on the max absolute difference.
        rel_thr: Strict upper bound on the max relative difference.
        inputs: Array fed to the model's first input; defaults to the
            module-level test input ``x``.
        reference: Expected first output; defaults to the module-level
            ``ref_out`` produced by the fp32 model.

    Returns:
        True if the output shape matches the reference and both max errors
        are strictly below their thresholds, otherwise False. A NaN anywhere
        in the difference makes the comparison fail (NaN < thr is False).
    """
    if inputs is None:
        inputs = x
    if reference is None:
        reference = ref_out

    sess = ort.InferenceSession(path, providers=['CPUExecutionProvider'])
    iname = sess.get_inputs()[0].name
    out = sess.run(None, {iname: np.asarray(inputs, dtype=np.float32)})[0]
    size_mb = os.path.getsize(path) / 1e6

    if np.shape(out) != np.shape(reference):
        # Without this guard NumPy would broadcast mismatched shapes (or raise)
        # instead of reporting a clean parity failure.
        print(f" {label:6s} {size_mb:.2f} MB shape mismatch: "
              f"got {np.shape(out)}, expected {np.shape(reference)} ✗")
        return False

    diff = np.abs(reference - out)
    max_abs = float(diff.max())
    # The 1e-8 epsilon avoids division by zero; relative error can still spike
    # where the reference is near zero.
    max_rel = float((diff / (np.abs(reference) + 1e-8)).max())
    ok = max_abs < abs_thr and max_rel < rel_thr
    status = "✓" if ok else "✗"
    print(f" {label:6s} {size_mb:.2f} MB max_abs={max_abs:.2e} max_rel={max_rel:.2e} {status}")
    return ok
print(f"Source: {fp32_path} ({os.path.getsize(fp32_path)/1e6:.2f} MB)\n")
# Parity result per variant label; used to set the process exit status at the end.
results = {}

# --- fp16 ---
# keep_io_types=True keeps the graph inputs/outputs in float32, so the same
# float32 test input can be fed to this variant.
print("Generating fp16...")
model_fp32 = onnx.load(fp32_path)
model_fp16 = float16.convert_float_to_float16(model_fp32, keep_io_types=True)
onnx.save(model_fp16, fp16_path)
print(f" Saved → {fp16_path}")
results['fp16'] = check(fp16_path, 'fp16', abs_thr=0.005, rel_thr=0.05)

# --- int8 dynamic quantization ---
# weight_type=QuantType.QInt8 stores the quantized weights as signed int8.
print("\nGenerating int8 (dynamic quantization)...")
quantize_dynamic(fp32_path, int8_path, weight_type=QuantType.QInt8)
print(f" Saved → {int8_path}")
# Gate on abs only; rel can spike near zero (same as DeepFormants-onnx int8).
results['int8'] = check(int8_path, 'int8', abs_thr=0.15, rel_thr=999.0)

print("\nDone.")
failed = [label for label, ok in results.items() if not ok]
if failed:
    # sys.exit with a message prints it to stderr and exits with status 1, so
    # CI or calling scripts can detect a parity regression.
    sys.exit(f"Parity check failed for: {', '.join(failed)}")