"""Compare PyTorch reference vs each ONNX variant on random inputs. Writes metadata.json with max_abs / max_rel diff per variant. Exit non-zero if any threshold exceeded. """ import json import sys from pathlib import Path import numpy as np import torch import torch.nn as nn import onnxruntime as ort ROOT = Path(__file__).resolve().parents[1] REPO = ROOT.parent / "DeepFormants" # Thresholds on the raw (unscaled) model output (output ~ formants/1000). THRESH = { "fp32": {"abs": 1e-4, "rel": 1e-3}, "fp16": {"abs": 5e-3, "rel": 5e-2}, "int8": {"abs": 1.5e-1, "rel": 5e-1}, } class MLP(nn.Module): def __init__(self): super().__init__() self.Dense1 = nn.Linear(350, 1024) self.Dense2 = nn.Linear(1024, 512) self.Dense3 = nn.Linear(512, 256) self.out = nn.Linear(256, 4) def forward(self, x): x = torch.sigmoid(self.Dense1(x)) x = torch.sigmoid(self.Dense2(x)) x = torch.sigmoid(self.Dense3(x)) return self.out(x) class Tracker(nn.Module): def __init__(self): super().__init__() self.lstm1 = nn.LSTM(350, 512, batch_first=True) self.lstm2 = nn.LSTM(512, 256, batch_first=True) self.fc = nn.Linear(256, 4) def forward(self, x): x, _ = self.lstm1(x) x, _ = self.lstm2(x) return self.fc(x) def diff_stats(a, b): a = a.astype(np.float64) b = b.astype(np.float64) abs_d = np.abs(a - b) rel_d = abs_d / (np.abs(a) + 1e-6) return float(abs_d.max()), float(rel_d.max()), float(abs_d.mean()) def run_variant(onnx_path, x_np, input_name="input"): sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"]) # fp16 model expects float16 input inputs = sess.get_inputs() expected = inputs[0].type x_in = x_np if "float16" in expected: x_in = x_np.astype(np.float16) out = sess.run(None, {inputs[0].name: x_in})[0] return out.astype(np.float32) def validate(name, model, ckpt, onnx_dir, sample_shape, n=20): model.load_state_dict(torch.load(ckpt, map_location="cpu", weights_only=True)) model.eval() torch.manual_seed(0) np.random.seed(0) results = {} fail = False for variant, suffix in [("fp32", "model.onnx"), ("fp16", "model_fp16.onnx"), ("int8", "model_int8.onnx")]: path = onnx_dir / suffix all_abs, all_rel, all_mean = [], [], [] for _ in range(n): x = np.random.randn(*sample_shape).astype(np.float32) with torch.no_grad(): ref = model(torch.from_numpy(x)).numpy() got = run_variant(path.as_posix(), x) a, r, m = diff_stats(ref, got) all_abs.append(a); all_rel.append(r); all_mean.append(m) max_abs = max(all_abs) max_rel = max(all_rel) mean_abs = float(np.mean(all_mean)) size_mb = path.stat().st_size / 1e6 t = THRESH[variant] ok = (max_abs <= t["abs"]) or (max_rel <= t["rel"]) results[variant] = { "file": suffix, "size_mb": round(size_mb, 3), "max_abs_diff": max_abs, "max_rel_diff": max_rel, "mean_abs_diff": mean_abs, "threshold_abs": t["abs"], "threshold_rel": t["rel"], "pass": ok, } status = "OK" if ok else "FAIL" print(f" [{variant}] {status} size={size_mb:.2f}MB " f"max_abs={max_abs:.3e} max_rel={max_rel:.3e} mean_abs={mean_abs:.3e}") if not ok: fail = True return results, fail def main(): print("LPC MLP estimator:") lpc_res, fail1 = validate( "lpc_estimator", MLP(), REPO / "pytorchFormants" / "Estimator" / "LPC_NN_scaledLoss.pt", ROOT / "lpc_estimator", sample_shape=(4, 350), ) print("\nLPC RNN tracker:") trk_res, fail2 = validate( "lpc_tracker", Tracker(), REPO / "pytorchFormants" / "Tracker" / "LPC_RNN.pt", ROOT / "lpc_tracker", sample_shape=(2, 20, 350), ) print("\nLPC MLP estimator (Torch7-origin):") t7_res, fail3 = validate( "lpc_estimator_torch7", MLP(), ROOT / "lpc_estimator_torch7" / "reconstructed.pt", ROOT / "lpc_estimator_torch7", sample_shape=(4, 350), ) print("\nLPC RNN tracker (Torch7-origin):") t7trk_res, fail4 = validate( "lpc_tracker_torch7", Tracker(), ROOT / "lpc_tracker_torch7" / "reconstructed.pt", ROOT / "lpc_tracker_torch7", sample_shape=(2, 20, 350), ) meta = { "models": { "lpc_estimator": { "source": "pytorchFormants/Estimator/LPC_NN_scaledLoss.pt", "architecture": "MLP 350->1024->512->256->4 (sigmoid hidden, linear out)", "input": {"name": "input", "shape": ["batch", 350], "dtype": "float32"}, "output": {"name": "formants", "shape": ["batch", 4], "note": "raw output ~ formant_Hz / 1000 (per repo convention)"}, "opset": 17, "variants": lpc_res, }, "lpc_tracker": { "source": "pytorchFormants/Tracker/LPC_RNN.pt", "architecture": "LSTM(350,512) -> LSTM(512,256) -> Linear(256,4)", "input": {"name": "input", "shape": ["batch", "time", 350], "dtype": "float32"}, "output": {"name": "formants", "shape": ["batch", "time", 4], "note": "raw output ~ formant_Hz / 1000 (per repo convention)"}, "opset": 17, "variants": trk_res, }, "lpc_estimator_torch7": { "source": "estimation_model.dat (Torch7 nn.Sequential, ported via torchfile)", "architecture": "MLP 350->1024->512->256->4 (sigmoid hidden, linear out) — identical to LPC_NN_scaledLoss.pt; different weights", "input": {"name": "input", "shape": ["batch", 350], "dtype": "float32"}, "output": {"name": "formants", "shape": ["batch", 4], "note": "raw output ~ formant_Hz / 1000 (×1000 for Hz, per load_estimation_model.lua)"}, "opset": 17, "variants": t7_res, "port_fidelity_hz": "max 0.003 Hz drift on real features vs float64 numpy reconstruction of Torch7 forward", }, "lpc_tracker_torch7": { "source": "tracking_model.dat (Torch7 nn.Sequential of nn.Sequencer+nn.FastLSTM, ported via torchfile)", "architecture": "LSTM(350,512) -> LSTM(512,256) -> Linear(256,4); identical shape to LPC_RNN.pt; different weights (original paper model)", "input": {"name": "input", "shape": ["batch", "time", 350], "dtype": "float32"}, "output": {"name": "formants", "shape": ["batch", "time", 4], "note": "raw output ~ formant_Hz / 1000"}, "opset": 17, "variants": t7trk_res, "gate_remap": "Torch7 FastLSTM [i,g,f,o] -> PyTorch nn.LSTM [i,f,g,o]; block perm [0,2,1,3]", "bias_convention": "Torch7 i2g.bias -> bias_ih_l0 (permuted); bias_hh_l0 = 0", "port_fidelity_hz": "max 0.0001 Hz drift on random input vs float64 numpy FastLSTM reference forward", }, }, "license": "MIT (DeepFormants repo). Weights derived from MLSpeech/DeepFormants. Local use; redistribution not verified.", "skipped": { "CNN_estimate.pt": "Checkpoint not shipped in the public repo.", }, } with open(ROOT / "metadata.json", "w") as f: json.dump(meta, f, indent=2) print(f"\nWrote {ROOT / 'metadata.json'}") sys.exit(1 if (fail1 or fail2 or fail3 or fail4) else 0) if __name__ == "__main__": main()