"""Compare PyTorch reference vs each ONNX variant on random inputs.
Writes metadata.json with max_abs / max_rel diff per variant.
Exits non-zero if any variant exceeds both its absolute and its relative threshold.
"""
import json
import sys
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import onnxruntime as ort
# Directory one level above the folder containing this script; the per-model
# ONNX folders and metadata.json are resolved relative to it.
ROOT = Path(__file__).resolve().parents[1]
# "DeepFormants" directory sitting next to ROOT; the PyTorch checkpoints
# under pytorchFormants/ are loaded from here.
REPO = ROOT.parent / "DeepFormants"
# Thresholds on the raw (unscaled) model output (output ~ formants/1000).
# Keyed by ONNX variant: "abs" bounds the max absolute diff and "rel" the max
# relative diff; validate() passes a variant when either bound is met.
THRESH = {
    "fp32": {"abs": 1e-4, "rel": 1e-3},
    "fp16": {"abs": 5e-3, "rel": 5e-2},
    "int8": {"abs": 1.5e-1, "rel": 5e-1},
}
class MLP(nn.Module):
    """Feed-forward estimator: 350 -> 1024 -> 512 -> 256 -> 4.

    The three hidden layers are followed by a sigmoid; the output layer is
    linear. The attribute names define the ``state_dict`` keys, so they must
    keep matching the checkpoints loaded into this module.
    """

    def __init__(self):
        super().__init__()
        self.Dense1 = nn.Linear(350, 1024)
        self.Dense2 = nn.Linear(1024, 512)
        self.Dense3 = nn.Linear(512, 256)
        self.out = nn.Linear(256, 4)

    def forward(self, x):
        """Apply the sigmoid hidden stack, then the linear output layer."""
        hidden = x
        for layer in (self.Dense1, self.Dense2, self.Dense3):
            hidden = torch.sigmoid(layer(hidden))
        return self.out(hidden)
class Tracker(nn.Module):
    """Two stacked batch-first LSTMs (350 -> 512 -> 256) with a linear head to 4.

    The attribute names define the ``state_dict`` keys, so they must keep
    matching the checkpoints loaded into this module.
    """

    def __init__(self):
        super().__init__()
        self.lstm1 = nn.LSTM(350, 512, batch_first=True)
        self.lstm2 = nn.LSTM(512, 256, batch_first=True)
        self.fc = nn.Linear(256, 4)

    def forward(self, x):
        """Run both LSTMs in sequence and project every time step to 4 values."""
        seq = x
        for recurrent in (self.lstm1, self.lstm2):
            # Only the per-step outputs are used; the final (h, c) state is dropped.
            seq, _state = recurrent(seq)
        return self.fc(seq)
def diff_stats(a, b):
    """Return ``(max_abs, max_rel, mean_abs)`` differences between two arrays.

    Both inputs are promoted to float64 before differencing. ``a`` is treated
    as the reference: the relative difference is ``|a - b| / (|a| + 1e-6)``,
    where the small epsilon avoids division by zero.

    Args:
        a: reference values (array-like).
        b: values under test (array-like) with the same shape as ``a``.

    Returns:
        Tuple of Python floats: max absolute diff, max relative diff and mean
        absolute diff.

    Raises:
        ValueError: if the shapes differ, or if the arrays are empty (raised
            by the NumPy ``max`` reduction).
    """
    ref = np.asarray(a, dtype=np.float64)
    got = np.asarray(b, dtype=np.float64)
    # Refuse to broadcast: a wrongly-shaped output must not yield
    # plausible-looking statistics.
    if ref.shape != got.shape:
        raise ValueError(
            f"shape mismatch: reference {ref.shape} vs candidate {got.shape}"
        )
    abs_d = np.abs(ref - got)
    rel_d = abs_d / (np.abs(ref) + 1e-6)
    return float(abs_d.max()), float(rel_d.max()), float(abs_d.mean())
def run_variant(onnx_path, x_np, input_name="input"):
    """Run one ONNX model on the CPU provider and return its first output.

    A new inference session is created on every call.

    Args:
        onnx_path: path of the ONNX file, as passed to ``ort.InferenceSession``.
        x_np: NumPy input array to feed.
        input_name: name of the graph input to feed. If the model has no
            input with this name, its first input is used instead.

    Returns:
        The first model output, cast to float32.
    """
    sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    inputs = sess.get_inputs()
    # Honour the requested input name; fall back to the first graph input so
    # models exported under a different name keep working.
    target = next((inp for inp in inputs if inp.name == input_name), inputs[0])
    # fp16 model expects float16 input
    x_in = x_np.astype(np.float16) if "float16" in target.type else x_np
    out = sess.run(None, {target.name: x_in})[0]
    return out.astype(np.float32)
def validate(name, model, ckpt, onnx_dir, sample_shape, n=20):
    """Compare a PyTorch model against its fp32 / fp16 / int8 ONNX exports.

    Loads ``ckpt`` into ``model``, then for each variant feeds ``n`` random
    normal batches through both the PyTorch model and the ONNX file and
    aggregates the differences. A variant passes when its max absolute diff
    OR its max relative diff is within the matching ``THRESH`` entry.

    Args:
        name: model label; currently unused, kept for the callers' signature.
        model: ``nn.Module`` instance whose weights are loaded from ``ckpt``.
        ckpt: path of the PyTorch state-dict checkpoint.
        onnx_dir: ``Path`` of the directory holding ``model.onnx``,
            ``model_fp16.onnx`` and ``model_int8.onnx``.
        sample_shape: shape tuple of each random input batch.
        n: number of random batches evaluated per variant.

    Returns:
        ``(results, fail)``: ``results`` maps variant name to a JSON-friendly
        stats dict, and ``fail`` is True if any variant did not pass.
    """
    model.load_state_dict(torch.load(ckpt, map_location="cpu", weights_only=True))
    model.eval()
    torch.manual_seed(0)
    np.random.seed(0)
    results = {}
    fail = False
    for variant, suffix in [("fp32", "model.onnx"),
                            ("fp16", "model_fp16.onnx"),
                            ("int8", "model_int8.onnx")]:
        path = onnx_dir / suffix
        all_abs, all_rel, all_mean = [], [], []
        for _ in range(n):
            x = np.random.randn(*sample_shape).astype(np.float32)
            with torch.no_grad():
                ref = model(torch.from_numpy(x)).numpy()
            got = run_variant(path.as_posix(), x)
            a, r, m = diff_stats(ref, got)
            all_abs.append(a)
            all_rel.append(r)
            all_mean.append(m)
        # np.max propagates NaN, whereas built-in max() can drop a NaN
        # depending on its position. A NaN makes both comparisons below False,
        # so the variant is reported as FAIL. Cast back to float so "pass"
        # stays a plain bool that json can serialise.
        max_abs = float(np.max(all_abs))
        max_rel = float(np.max(all_rel))
        mean_abs = float(np.mean(all_mean))
        size_mb = path.stat().st_size / 1e6
        t = THRESH[variant]
        ok = (max_abs <= t["abs"]) or (max_rel <= t["rel"])
        results[variant] = {
            "file": suffix,
            "size_mb": round(size_mb, 3),
            "max_abs_diff": max_abs,
            "max_rel_diff": max_rel,
            "mean_abs_diff": mean_abs,
            "threshold_abs": t["abs"],
            "threshold_rel": t["rel"],
            "pass": ok,
        }
        status = "OK" if ok else "FAIL"
        print(f" [{variant}] {status} size={size_mb:.2f}MB "
              f"max_abs={max_abs:.3e} max_rel={max_rel:.3e} mean_abs={mean_abs:.3e}")
        if not ok:
            fail = True
    return results, fail
def main():
    """Validate every exported model, write ``metadata.json`` and exit.

    Runs ``validate`` for the four models (the two shipped PyTorch
    checkpoints and their Torch7-origin reconstructions), stores the
    per-variant results next to static model descriptions in
    ``ROOT / "metadata.json"``, then exits with status 1 if any validation
    failed and 0 otherwise.
    """
    mlp_shape = (4, 350)
    rnn_shape = (2, 20, 350)
    # (banner, model key == ONNX folder name, model class, checkpoint, input shape)
    jobs = [
        ("LPC MLP estimator:", "lpc_estimator", MLP,
         REPO / "pytorchFormants" / "Estimator" / "LPC_NN_scaledLoss.pt",
         mlp_shape),
        ("\nLPC RNN tracker:", "lpc_tracker", Tracker,
         REPO / "pytorchFormants" / "Tracker" / "LPC_RNN.pt",
         rnn_shape),
        ("\nLPC MLP estimator (Torch7-origin):", "lpc_estimator_torch7", MLP,
         ROOT / "lpc_estimator_torch7" / "reconstructed.pt",
         mlp_shape),
        ("\nLPC RNN tracker (Torch7-origin):", "lpc_tracker_torch7", Tracker,
         ROOT / "lpc_tracker_torch7" / "reconstructed.pt",
         rnn_shape),
    ]

    variants = {}
    any_failed = False
    for banner, key, model_cls, ckpt, shape in jobs:
        print(banner)
        # The model is instantiated only after its banner is printed.
        outcome, failed = validate(
            key, model_cls(), ckpt, ROOT / key, sample_shape=shape
        )
        variants[key] = outcome
        any_failed = any_failed or failed

    # Input specs shared by the estimator pair and the tracker pair.
    mlp_input = {"name": "input", "shape": ["batch", 350], "dtype": "float32"}
    rnn_input = {"name": "input", "shape": ["batch", "time", 350], "dtype": "float32"}

    models = {}
    models["lpc_estimator"] = {
        "source": "pytorchFormants/Estimator/LPC_NN_scaledLoss.pt",
        "architecture": "MLP 350->1024->512->256->4 (sigmoid hidden, linear out)",
        "input": mlp_input,
        "output": {
            "name": "formants",
            "shape": ["batch", 4],
            "note": "raw output ~ formant_Hz / 1000 (per repo convention)",
        },
        "opset": 17,
        "variants": variants["lpc_estimator"],
    }
    models["lpc_tracker"] = {
        "source": "pytorchFormants/Tracker/LPC_RNN.pt",
        "architecture": "LSTM(350,512) -> LSTM(512,256) -> Linear(256,4)",
        "input": rnn_input,
        "output": {
            "name": "formants",
            "shape": ["batch", "time", 4],
            "note": "raw output ~ formant_Hz / 1000 (per repo convention)",
        },
        "opset": 17,
        "variants": variants["lpc_tracker"],
    }
    models["lpc_estimator_torch7"] = {
        "source": "estimation_model.dat (Torch7 nn.Sequential, ported via torchfile)",
        "architecture": "MLP 350->1024->512->256->4 (sigmoid hidden, linear out) — identical to LPC_NN_scaledLoss.pt; different weights",
        "input": mlp_input,
        "output": {
            "name": "formants",
            "shape": ["batch", 4],
            "note": "raw output ~ formant_Hz / 1000 (×1000 for Hz, per load_estimation_model.lua)",
        },
        "opset": 17,
        "variants": variants["lpc_estimator_torch7"],
        "port_fidelity_hz": "max 0.003 Hz drift on real features vs float64 numpy reconstruction of Torch7 forward",
    }
    models["lpc_tracker_torch7"] = {
        "source": "tracking_model.dat (Torch7 nn.Sequential of nn.Sequencer+nn.FastLSTM, ported via torchfile)",
        "architecture": "LSTM(350,512) -> LSTM(512,256) -> Linear(256,4); identical shape to LPC_RNN.pt; different weights (original paper model)",
        "input": rnn_input,
        "output": {
            "name": "formants",
            "shape": ["batch", "time", 4],
            "note": "raw output ~ formant_Hz / 1000",
        },
        "opset": 17,
        "variants": variants["lpc_tracker_torch7"],
        "gate_remap": "Torch7 FastLSTM [i,g,f,o] -> PyTorch nn.LSTM [i,f,g,o]; block perm [0,2,1,3]",
        "bias_convention": "Torch7 i2g.bias -> bias_ih_l0 (permuted); bias_hh_l0 = 0",
        "port_fidelity_hz": "max 0.0001 Hz drift on random input vs float64 numpy FastLSTM reference forward",
    }

    meta = {
        "models": models,
        "license": "MIT (DeepFormants repo). Weights derived from MLSpeech/DeepFormants. Local use; redistribution not verified.",
        "skipped": {
            "CNN_estimate.pt": "Checkpoint not shipped in the public repo.",
        },
    }

    # The metadata is written even when a variant failed; only the exit status
    # reflects the failure.
    meta_path = ROOT / "metadata.json"
    with meta_path.open("w") as f:
        json.dump(meta, f, indent=2)
    print(f"\nWrote {meta_path}")
    sys.exit(int(any_failed))
# Run the validation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|