FredrikKarlssonSpeech commited on
Commit
ee019e8
·
verified ·
1 Parent(s): 8cd9035

Upload quantize_onnx.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. quantize_onnx.py +53 -0
quantize_onnx.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Generate fp16 and int8 variants of formantnet.onnx.
4
+ Run in onnxconv env.
5
+ """
6
+ import os, sys
7
+ import numpy as np
8
+ import onnx
9
+ import onnxruntime as ort
10
+ from onnxconverter_common import float16
11
+ from onnxruntime.quantization import quantize_dynamic, QuantType
12
+
13
+ SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
14
+ fp32_path = os.path.join(SCRIPT_DIR, 'formantnet.onnx')
15
+ fp16_path = os.path.join(SCRIPT_DIR, 'formantnet_fp16.onnx')
16
+ int8_path = os.path.join(SCRIPT_DIR, 'formantnet_int8.onnx')
17
+
18
+ # Reference output for parity checks
19
+ np.random.seed(42)
20
+ x = np.random.randn(1, 200, 257).astype(np.float32)
21
+ ref_sess = ort.InferenceSession(fp32_path, providers=['CPUExecutionProvider'])
22
+ input_name = ref_sess.get_inputs()[0].name
23
+ ref_out = ref_sess.run(None, {input_name: x})[0]
24
+
25
+ def check(path, label, abs_thr, rel_thr):
26
+ sess = ort.InferenceSession(path, providers=['CPUExecutionProvider'])
27
+ iname = sess.get_inputs()[0].name
28
+ out = sess.run(None, {iname: x.astype(np.float32)})[0]
29
+ max_abs = float(np.abs(ref_out - out).max())
30
+ max_rel = float((np.abs(ref_out - out) / (np.abs(ref_out) + 1e-8)).max())
31
+ size_mb = os.path.getsize(path) / 1e6
32
+ ok = max_abs < abs_thr and max_rel < rel_thr
33
+ status = "✓" if ok else "✗"
34
+ print(f" {label:6s} {size_mb:.2f} MB max_abs={max_abs:.2e} max_rel={max_rel:.2e} {status}")
35
+ return ok
36
+
37
+ print(f"Source: {fp32_path} ({os.path.getsize(fp32_path)/1e6:.2f} MB)\n")
38
+
39
+ # --- fp16 ---
40
+ print("Generating fp16...")
41
+ model_fp32 = onnx.load(fp32_path)
42
+ model_fp16 = float16.convert_float_to_float16(model_fp32, keep_io_types=True)
43
+ onnx.save(model_fp16, fp16_path)
44
+ print(f" Saved → {fp16_path}")
45
+ check(fp16_path, 'fp16', abs_thr=0.005, rel_thr=0.05)
46
+
47
+ # --- int8 dynamic quantization ---
48
+ print("\nGenerating int8 (dynamic quantization)...")
49
+ quantize_dynamic(fp32_path, int8_path, weight_type=QuantType.QInt8)
50
+ print(f" Saved → {int8_path}")
51
+ check(int8_path, 'int8', abs_thr=0.15, rel_thr=999.0) # gate on abs only; rel can spike near zero (same as DeepFormants-onnx int8)
52
+
53
+ print("\nDone.")