# DeepFormants / scripts/validate_parity.py
# Uploaded by FredrikKarlssonSpeech:
#   "Upload DeepFormants ONNX (fp32/fp16/int8) for LPC estimator + LSTM tracker"
#   commit 773c4c9 (verified)
"""Compare the PyTorch reference against each ONNX variant on random inputs.

For each exported model, the fp32, fp16 and int8 ONNX files are run on
standard-normal random inputs (np.random.randn) alongside the PyTorch
reference. The max absolute / max relative difference per variant is
recorded, and all results are written to metadata.json.

A variant passes when its max absolute diff OR its max relative diff is within
the per-variant thresholds in THRESH. The script exits non-zero if any variant
fails.
"""
import json
import sys
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import onnxruntime as ort
# ROOT is the directory two levels above this file (the parent of scripts/).
# REPO is a sibling directory of ROOT named "DeepFormants"; the upstream
# PyTorch checkpoints are loaded from there.
ROOT = Path(__file__).resolve().parents[1]
REPO = ROOT.parent / "DeepFormants"

# Pass thresholds per ONNX variant, applied to the raw (unscaled) model output
# (output ~ formants/1000). "abs" bounds the max absolute difference and "rel"
# bounds the max relative difference against the PyTorch reference.
THRESH = {
    variant: {"abs": abs_tol, "rel": rel_tol}
    for variant, abs_tol, rel_tol in (
        ("fp32", 1e-4, 1e-3),
        ("fp16", 5e-3, 5e-2),
        ("int8", 1.5e-1, 5e-1),
    )
}
class MLP(nn.Module):
    """Feed-forward LPC formant estimator: 350 -> 1024 -> 512 -> 256 -> 4.

    The three hidden layers use a sigmoid activation; the output layer is
    linear. The attribute names (Dense1..Dense3, out) determine the state_dict
    keys that the checkpoint is loaded into, so they must stay as-is.
    """

    def __init__(self):
        super().__init__()
        # Keep creation order: it fixes parameter registration order and the
        # RNG stream used for the default initialisation.
        self.Dense1 = nn.Linear(350, 1024)
        self.Dense2 = nn.Linear(1024, 512)
        self.Dense3 = nn.Linear(512, 256)
        self.out = nn.Linear(256, 4)

    def forward(self, x):
        for hidden in (self.Dense1, self.Dense2, self.Dense3):
            x = torch.sigmoid(hidden(x))
        return self.out(x)
class Tracker(nn.Module):
    """Recurrent LPC formant tracker: LSTM(350,512) -> LSTM(512,256) -> Linear(256,4).

    Both LSTMs are constructed with batch_first=True, so batched input is laid
    out as (batch, time, 350) and the output as (batch, time, 4). The attribute
    names determine the state_dict keys the checkpoint is loaded into.
    """

    def __init__(self):
        super().__init__()
        self.lstm1 = nn.LSTM(350, 512, batch_first=True)
        self.lstm2 = nn.LSTM(512, 256, batch_first=True)
        self.fc = nn.Linear(256, 4)

    def forward(self, x):
        # nn.LSTM returns (output, (h_n, c_n)); only the per-step output is used.
        seq = self.lstm1(x)[0]
        seq = self.lstm2(seq)[0]
        return self.fc(seq)
def diff_stats(a, b):
    """Return (max_abs, max_rel, mean_abs) differences between *a* and *b*.

    *a* is the reference. The relative difference is |a - b| / (|a| + 1e-6);
    the epsilon keeps near-zero reference values from dividing by zero. Both
    inputs are promoted to float64 so the subtraction is not done in the
    candidate's lower precision (e.g. float16).

    Raises:
        ValueError: if the shapes differ, because NumPy broadcasting would
            otherwise compare mismatched elements and hide a wrong-shaped
            output; or if the arrays are empty.
    """
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    if a.shape != b.shape:
        raise ValueError(
            f"shape mismatch: reference {a.shape} vs candidate {b.shape}")
    if a.size == 0:
        raise ValueError("cannot compute diff statistics of empty arrays")
    abs_d = np.abs(a - b)
    rel_d = abs_d / (np.abs(a) + 1e-6)
    return float(abs_d.max()), float(rel_d.max()), float(abs_d.mean())
# Sessions keyed by the onnx_path value passed to run_variant; entries live
# until the process exits.
_SESSION_CACHE = {}


def _get_session(onnx_path):
    """Return a CPU InferenceSession for *onnx_path*, building it on first use."""
    sess = _SESSION_CACHE.get(onnx_path)
    if sess is None:
        sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
        _SESSION_CACHE[onnx_path] = sess
    return sess


def run_variant(onnx_path, x_np, input_name="input"):
    """Run one ONNX model on *x_np* and return its first output as float32.

    The session for each path is built once and reused. validate() calls this
    once per random sample, so without reuse the same model would be
    re-loaded and re-optimised for every sample.

    Args:
        onnx_path: value passed to ort.InferenceSession (e.g. a path string);
            also used as the session cache key.
        x_np: float32 input array.
        input_name: preferred graph input to feed. If the model has no input
            with this name, its first input is fed instead.

    Returns:
        The model's first output, cast to float32.
    """
    sess = _get_session(onnx_path)
    inputs = sess.get_inputs()
    target = next((i for i in inputs if i.name == input_name), inputs[0])
    # The fp16 export declares a float16 input; other variants take float32.
    x_in = x_np.astype(np.float16) if "float16" in target.type else x_np
    out = sess.run(None, {target.name: x_in})[0]
    return out.astype(np.float32)
def validate(name, model, ckpt, onnx_dir, sample_shape, n=20):
    """Check every ONNX variant in *onnx_dir* against the PyTorch *model*.

    Args:
        name: identifier of the model being validated (not used in the
            computation; kept for call-site readability).
        model: nn.Module whose state_dict matches *ckpt*.
        ckpt: path to a checkpoint containing a state_dict for *model*.
        onnx_dir: pathlib.Path directory holding model.onnx, model_fp16.onnx
            and model_int8.onnx.
        sample_shape: shape of each random input (drawn with np.random.randn).
        n: number of random inputs; must be >= 1.

    Returns:
        (results, fail): per-variant statistics dict, and True if any variant
        failed its thresholds.

    Raises:
        ValueError: if n < 1.
    """
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")
    model.load_state_dict(torch.load(ckpt, map_location="cpu", weights_only=True))
    model.eval()
    torch.manual_seed(0)
    np.random.seed(0)

    # Draw the inputs and compute the PyTorch reference once. Every variant is
    # then judged on the same samples, and the reference forward pass is not
    # repeated per variant.
    samples = [np.random.randn(*sample_shape).astype(np.float32) for _ in range(n)]
    with torch.no_grad():
        references = [model(torch.from_numpy(x)).numpy() for x in samples]

    results = {}
    fail = False
    for variant, suffix in [("fp32", "model.onnx"),
                            ("fp16", "model_fp16.onnx"),
                            ("int8", "model_int8.onnx")]:
        path = onnx_dir / suffix
        per_sample = [diff_stats(ref, run_variant(path.as_posix(), x))
                      for x, ref in zip(samples, references)]
        max_abs = max(s[0] for s in per_sample)
        max_rel = max(s[1] for s in per_sample)
        mean_abs = float(np.mean([s[2] for s in per_sample]))
        size_mb = path.stat().st_size / 1e6
        t = THRESH[variant]
        # Pass if EITHER the max-abs bound OR the max-rel bound holds. Each
        # bound is on the worst element, not an element-wise allclose.
        ok = (max_abs <= t["abs"]) or (max_rel <= t["rel"])
        results[variant] = {
            "file": suffix,
            "size_mb": round(size_mb, 3),
            "max_abs_diff": max_abs,
            "max_rel_diff": max_rel,
            "mean_abs_diff": mean_abs,
            "threshold_abs": t["abs"],
            "threshold_rel": t["rel"],
            "pass": ok,
        }
        status = "OK" if ok else "FAIL"
        print(f" [{variant}] {status} size={size_mb:.2f}MB "
              f"max_abs={max_abs:.3e} max_rel={max_rel:.3e} mean_abs={mean_abs:.3e}")
        if not ok:
            fail = True
    return results, fail
def main():
    """Validate all four exported models, write metadata.json, set the exit code.

    The exit status is 1 if any ONNX variant of any model failed its
    thresholds, and 0 otherwise.
    """
    # (printed header, model key, model class, checkpoint, ONNX dir, sample shape)
    jobs = [
        ("LPC MLP estimator:", "lpc_estimator", MLP,
         REPO / "pytorchFormants" / "Estimator" / "LPC_NN_scaledLoss.pt",
         ROOT / "lpc_estimator", (4, 350)),
        ("\nLPC RNN tracker:", "lpc_tracker", Tracker,
         REPO / "pytorchFormants" / "Tracker" / "LPC_RNN.pt",
         ROOT / "lpc_tracker", (2, 20, 350)),
        ("\nLPC MLP estimator (Torch7-origin):", "lpc_estimator_torch7", MLP,
         ROOT / "lpc_estimator_torch7" / "reconstructed.pt",
         ROOT / "lpc_estimator_torch7", (4, 350)),
        ("\nLPC RNN tracker (Torch7-origin):", "lpc_tracker_torch7", Tracker,
         ROOT / "lpc_tracker_torch7" / "reconstructed.pt",
         ROOT / "lpc_tracker_torch7", (2, 20, 350)),
    ]

    results = {}
    failures = []
    for header, key, model_cls, ckpt, onnx_dir, shape in jobs:
        print(header)
        results[key], failed = validate(key, model_cls(), ckpt, onnx_dir,
                                        sample_shape=shape)
        failures.append(failed)

    def model_entry(source, architecture, lead_dims, note, variants, **extra):
        # One metadata record. lead_dims are the symbolic axes before the
        # feature axis, e.g. ["batch"] or ["batch", "time"].
        return {
            "source": source,
            "architecture": architecture,
            "input": {"name": "input", "shape": [*lead_dims, 350], "dtype": "float32"},
            "output": {"name": "formants", "shape": [*lead_dims, 4], "note": note},
            "opset": 17,
            "variants": variants,
            **extra,
        }

    repo_note = "raw output ~ formant_Hz / 1000 (per repo convention)"
    meta = {
        "models": {
            "lpc_estimator": model_entry(
                "pytorchFormants/Estimator/LPC_NN_scaledLoss.pt",
                "MLP 350->1024->512->256->4 (sigmoid hidden, linear out)",
                ["batch"], repo_note, results["lpc_estimator"],
            ),
            "lpc_tracker": model_entry(
                "pytorchFormants/Tracker/LPC_RNN.pt",
                "LSTM(350,512) -> LSTM(512,256) -> Linear(256,4)",
                ["batch", "time"], repo_note, results["lpc_tracker"],
            ),
            "lpc_estimator_torch7": model_entry(
                "estimation_model.dat (Torch7 nn.Sequential, ported via torchfile)",
                "MLP 350->1024->512->256->4 (sigmoid hidden, linear out) — identical to LPC_NN_scaledLoss.pt; different weights",
                ["batch"],
                "raw output ~ formant_Hz / 1000 (×1000 for Hz, per load_estimation_model.lua)",
                results["lpc_estimator_torch7"],
                port_fidelity_hz="max 0.003 Hz drift on real features vs float64 numpy reconstruction of Torch7 forward",
            ),
            "lpc_tracker_torch7": model_entry(
                "tracking_model.dat (Torch7 nn.Sequential of nn.Sequencer+nn.FastLSTM, ported via torchfile)",
                "LSTM(350,512) -> LSTM(512,256) -> Linear(256,4); identical shape to LPC_RNN.pt; different weights (original paper model)",
                ["batch", "time"],
                "raw output ~ formant_Hz / 1000",
                results["lpc_tracker_torch7"],
                gate_remap="Torch7 FastLSTM [i,g,f,o] -> PyTorch nn.LSTM [i,f,g,o]; block perm [0,2,1,3]",
                bias_convention="Torch7 i2g.bias -> bias_ih_l0 (permuted); bias_hh_l0 = 0",
                port_fidelity_hz="max 0.0001 Hz drift on random input vs float64 numpy FastLSTM reference forward",
            ),
        },
        "license": "MIT (DeepFormants repo). Weights derived from MLSpeech/DeepFormants. Local use; redistribution not verified.",
        "skipped": {
            "CNN_estimate.pt": "Checkpoint not shipped in the public repo.",
        },
    }

    out_path = ROOT / "metadata.json"
    with open(out_path, "w") as f:
        json.dump(meta, f, indent=2)
    print(f"\nWrote {out_path}")
    sys.exit(int(any(failures)))


if __name__ == "__main__":
    main()