| """Compare PyTorch reference vs each ONNX variant on random inputs. |
| |
| Writes metadata.json with max_abs / max_rel / mean_abs diff per variant. |
| Exits non-zero if any variant exceeds both its abs and rel thresholds. |
| """ |
| import json |
| import sys |
| from pathlib import Path |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import onnxruntime as ort |
|
|
# Grandparent directory of this script. The per-model ONNX directories and
# metadata.json are addressed relative to it in main().
ROOT = Path(__file__).resolve().parents[1]
# "DeepFormants" directory located next to ROOT; main() reads the PyTorch
# checkpoints from below it.
REPO = ROOT.parent.joinpath("DeepFormants")

# Tolerances per ONNX variant, keyed by variant name. Each entry carries the
# absolute ("abs") and relative ("rel") limit that validate() compares the
# worst observed diff against.
THRESH = {
    variant: {"abs": abs_tol, "rel": rel_tol}
    for variant, abs_tol, rel_tol in (
        ("fp32", 1e-4, 1e-3),
        ("fp16", 5e-3, 5e-2),
        ("int8", 1.5e-1, 5e-1),
    )
}
|
|
|
|
class MLP(nn.Module):
    """Fully connected network 350 -> 1024 -> 512 -> 256 -> 4.

    The three hidden layers are followed by a sigmoid; the final layer is
    linear. validate() fills this module via ``load_state_dict``, so the
    attribute names below have to keep matching the checkpoint keys.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in forward order, widest first.
        self.Dense1 = nn.Linear(in_features=350, out_features=1024)
        self.Dense2 = nn.Linear(in_features=1024, out_features=512)
        self.Dense3 = nn.Linear(in_features=512, out_features=256)
        self.out = nn.Linear(in_features=256, out_features=4)

    def forward(self, x):
        """Map a ``(..., 350)`` tensor to a ``(..., 4)`` tensor."""
        hidden = x
        for layer in (self.Dense1, self.Dense2, self.Dense3):
            hidden = torch.sigmoid(layer(hidden))
        # No activation on the output layer.
        return self.out(hidden)
|
|
|
|
class Tracker(nn.Module):
    """Two stacked batch-first LSTMs (350 -> 512 -> 256) plus a linear head to 4.

    validate() fills this module via ``load_state_dict``, so the attribute
    names below have to keep matching the checkpoint keys.
    """

    def __init__(self):
        super().__init__()
        self.lstm1 = nn.LSTM(input_size=350, hidden_size=512, batch_first=True)
        self.lstm2 = nn.LSTM(input_size=512, hidden_size=256, batch_first=True)
        self.fc = nn.Linear(in_features=256, out_features=4)

    def forward(self, x):
        """Map a ``(batch, time, 350)`` tensor to ``(batch, time, 4)``."""
        sequence = x
        for recurrent in (self.lstm1, self.lstm2):
            # Keep only the per-step outputs; the final (h, c) state is unused.
            sequence, _state = recurrent(sequence)
        return self.fc(sequence)
|
|
|
|
def diff_stats(a, b):
    """Return ``(max_abs, max_rel, mean_abs)`` difference between two arrays.

    Both inputs are converted to float64 before comparison. The relative
    difference is taken with respect to ``a`` (the reference), with 1e-6
    added to the denominator so that zeros in ``a`` do not divide by zero.

    Args:
        a: Reference values (array-like).
        b: Values under test (array-like) with exactly the same shape as ``a``.

    Returns:
        Tuple of three Python floats: maximum absolute difference, maximum
        relative difference, mean absolute difference.

    Raises:
        ValueError: If the shapes differ or the inputs are empty.
    """
    ref = np.asarray(a, dtype=np.float64)
    got = np.asarray(b, dtype=np.float64)
    # Refuse to broadcast: a wrongly shaped output must not be compared as if
    # it matched the reference.
    if ref.shape != got.shape:
        raise ValueError(
            f"shape mismatch: reference {ref.shape} vs compared {got.shape}"
        )
    if ref.size == 0:
        raise ValueError("cannot compute diff statistics of empty arrays")
    abs_d = np.abs(ref - got)
    rel_d = abs_d / (np.abs(ref) + 1e-6)
    return float(abs_d.max()), float(rel_d.max()), float(abs_d.mean())
|
|
|
|
def run_variant(onnx_path, x_np, input_name="input"):
    """Run one ONNX model on the CPU provider and return its first output.

    Args:
        onnx_path: Path of the ONNX file handed to ``ort.InferenceSession``.
        x_np: NumPy array fed to the model's first input.
        input_name: Not used; the feed name is taken from the model's first
            declared input. Kept so existing callers keep working.

    Returns:
        The first model output converted to float32.
    """
    sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    first_input = sess.get_inputs()[0]

    # Match the declared element type exactly. A substring test for "float16"
    # would also match "tensor(bfloat16)" and cast to the wrong dtype.
    float_dtypes = {
        "tensor(float16)": np.float16,
        "tensor(float)": np.float32,
        "tensor(double)": np.float64,
    }
    wanted = float_dtypes.get(first_input.type)
    # Inputs of any other declared type are passed through unchanged.
    x_in = x_np if wanted is None else x_np.astype(wanted, copy=False)

    out = sess.run(None, {first_input.name: x_in})[0]
    return out.astype(np.float32)
|
|
|
|
def validate(name, model, ckpt, onnx_dir, sample_shape, n=20):
    """Compare a PyTorch model with its fp32 / fp16 / int8 ONNX exports.

    The checkpoint is loaded into ``model``, then for each variant ``n``
    random float32 inputs are drawn and the PyTorch output is compared with
    the ONNX output via ``diff_stats``. Both RNGs are seeded once, before the
    variant loop, so every variant consumes its own stretch of the random
    stream.

    A variant passes when its worst absolute diff OR its worst relative diff
    is within the matching ``THRESH`` entry; it fails only when both are
    exceeded.

    Args:
        name: Label of the model; not used by this function.
        model: ``nn.Module`` whose weights are loaded from ``ckpt``.
        ckpt: Path of the state-dict checkpoint.
        onnx_dir: Directory (``Path``) holding ``model.onnx``,
            ``model_fp16.onnx`` and ``model_int8.onnx``.
        sample_shape: Shape of each random input sample.
        n: Number of random samples per variant.

    Returns:
        ``(results, fail)``: ``results`` maps the variant name to a dict of
        file name, size and diff statistics; ``fail`` is True if at least one
        variant did not pass.
    """
    model.load_state_dict(torch.load(ckpt, map_location="cpu", weights_only=True))
    model.eval()
    torch.manual_seed(0)
    np.random.seed(0)

    results = {}
    fail = False
    for variant, suffix in [("fp32", "model.onnx"),
                            ("fp16", "model_fp16.onnx"),
                            ("int8", "model_int8.onnx")]:
        path = onnx_dir / suffix
        all_abs, all_rel, all_mean = [], [], []
        for _ in range(n):
            x = np.random.randn(*sample_shape).astype(np.float32)
            with torch.no_grad():
                ref = model(torch.from_numpy(x)).numpy()
            got = run_variant(path.as_posix(), x)
            a, r, m = diff_stats(ref, got)
            all_abs.append(a)
            all_rel.append(r)
            all_mean.append(m)
        # np.max propagates NaN. The built-in max() is order-dependent with
        # NaN and could drop a NaN diff, letting a broken variant pass.
        max_abs = float(np.max(all_abs))
        max_rel = float(np.max(all_rel))
        mean_abs = float(np.mean(all_mean))
        size_mb = path.stat().st_size / 1e6
        t = THRESH[variant]
        # Either tolerance is sufficient; a NaN compares False on both sides
        # and therefore fails the variant.
        ok = (max_abs <= t["abs"]) or (max_rel <= t["rel"])
        results[variant] = {
            "file": suffix,
            "size_mb": round(size_mb, 3),
            "max_abs_diff": max_abs,
            "max_rel_diff": max_rel,
            "mean_abs_diff": mean_abs,
            "threshold_abs": t["abs"],
            "threshold_rel": t["rel"],
            "pass": ok,
        }
        status = "OK" if ok else "FAIL"
        print(f" [{variant}] {status} size={size_mb:.2f}MB "
              f"max_abs={max_abs:.3e} max_rel={max_rel:.3e} mean_abs={mean_abs:.3e}")
        if not ok:
            fail = True
    return results, fail
|
|
|
|
def main():
    """Validate all four exported models, write metadata.json, then exit.

    Each model is run through validate(); the per-variant results are stored
    in ``ROOT / "metadata.json"`` next to static descriptions of the models.
    The process exits with status 1 if any validation reported a failure and
    with status 0 otherwise.
    """
    # (console heading, model label, model class, checkpoint, ONNX dir, input shape)
    jobs = [
        ("LPC MLP estimator:", "lpc_estimator", MLP,
         REPO / "pytorchFormants" / "Estimator" / "LPC_NN_scaledLoss.pt",
         ROOT / "lpc_estimator", (4, 350)),
        ("\nLPC RNN tracker:", "lpc_tracker", Tracker,
         REPO / "pytorchFormants" / "Tracker" / "LPC_RNN.pt",
         ROOT / "lpc_tracker", (2, 20, 350)),
        ("\nLPC MLP estimator (Torch7-origin):", "lpc_estimator_torch7", MLP,
         ROOT / "lpc_estimator_torch7" / "reconstructed.pt",
         ROOT / "lpc_estimator_torch7", (4, 350)),
        ("\nLPC RNN tracker (Torch7-origin):", "lpc_tracker_torch7", Tracker,
         ROOT / "lpc_tracker_torch7" / "reconstructed.pt",
         ROOT / "lpc_tracker_torch7", (2, 20, 350)),
    ]

    variants = {}
    any_failed = False
    for heading, label, model_cls, ckpt, onnx_dir, shape in jobs:
        print(heading)
        # A fresh model instance is built right before it is validated.
        variants[label], failed = validate(
            label, model_cls(), ckpt, onnx_dir, sample_shape=shape,
        )
        any_failed = any_failed or failed

    # Input descriptions shared by the two MLP entries and the two RNN entries.
    mlp_input = {"name": "input", "shape": ["batch", 350], "dtype": "float32"}
    rnn_input = {"name": "input", "shape": ["batch", "time", 350], "dtype": "float32"}

    models = {}
    models["lpc_estimator"] = {
        "source": "pytorchFormants/Estimator/LPC_NN_scaledLoss.pt",
        "architecture": "MLP 350->1024->512->256->4 (sigmoid hidden, linear out)",
        "input": mlp_input,
        "output": {"name": "formants", "shape": ["batch", 4],
                   "note": "raw output ~ formant_Hz / 1000 (per repo convention)"},
        "opset": 17,
        "variants": variants["lpc_estimator"],
    }
    models["lpc_tracker"] = {
        "source": "pytorchFormants/Tracker/LPC_RNN.pt",
        "architecture": "LSTM(350,512) -> LSTM(512,256) -> Linear(256,4)",
        "input": rnn_input,
        "output": {"name": "formants", "shape": ["batch", "time", 4],
                   "note": "raw output ~ formant_Hz / 1000 (per repo convention)"},
        "opset": 17,
        "variants": variants["lpc_tracker"],
    }
    models["lpc_estimator_torch7"] = {
        "source": "estimation_model.dat (Torch7 nn.Sequential, ported via torchfile)",
        "architecture": "MLP 350->1024->512->256->4 (sigmoid hidden, linear out) — identical to LPC_NN_scaledLoss.pt; different weights",
        "input": mlp_input,
        "output": {"name": "formants", "shape": ["batch", 4],
                   "note": "raw output ~ formant_Hz / 1000 (×1000 for Hz, per load_estimation_model.lua)"},
        "opset": 17,
        "variants": variants["lpc_estimator_torch7"],
        "port_fidelity_hz": "max 0.003 Hz drift on real features vs float64 numpy reconstruction of Torch7 forward",
    }
    models["lpc_tracker_torch7"] = {
        "source": "tracking_model.dat (Torch7 nn.Sequential of nn.Sequencer+nn.FastLSTM, ported via torchfile)",
        "architecture": "LSTM(350,512) -> LSTM(512,256) -> Linear(256,4); identical shape to LPC_RNN.pt; different weights (original paper model)",
        "input": rnn_input,
        "output": {"name": "formants", "shape": ["batch", "time", 4],
                   "note": "raw output ~ formant_Hz / 1000"},
        "opset": 17,
        "variants": variants["lpc_tracker_torch7"],
        "gate_remap": "Torch7 FastLSTM [i,g,f,o] -> PyTorch nn.LSTM [i,f,g,o]; block perm [0,2,1,3]",
        "bias_convention": "Torch7 i2g.bias -> bias_ih_l0 (permuted); bias_hh_l0 = 0",
        "port_fidelity_hz": "max 0.0001 Hz drift on random input vs float64 numpy FastLSTM reference forward",
    }

    meta = {
        "models": models,
        "license": "MIT (DeepFormants repo). Weights derived from MLSpeech/DeepFormants. Local use; redistribution not verified.",
        "skipped": {
            "CNN_estimate.pt": "Checkpoint not shipped in the public repo.",
        },
    }

    meta_path = ROOT / "metadata.json"
    with open(meta_path, "w") as f:
        json.dump(meta, f, indent=2)
    print(f"\nWrote {meta_path}")
    # The metadata is written even when a variant failed; only the exit
    # status reports the failure.
    sys.exit(int(any_failed))


if __name__ == "__main__":
    main()
|
|