# DeepFormants / scripts / port_torch7_tracker.py
# Committed by FredrikKarlssonSpeech
# Commit message: "Upload DeepFormants ONNX (fp32/fp16/int8) for LPC estimator + LSTM tracker"
# Commit 773c4c9 (verified)
"""Port DeepFormants tracking_model.dat (Torch7 nn.Sequencer + FastLSTM) → PyTorch.
Architecture (confirmed via torchfile probe):
nn.Sequential
├── nn.Sequencer( nn.FastLSTM ) hidden=512 i2g:(2048,350) o2g:(2048,512)
├── nn.Sequencer( nn.FastLSTM ) hidden=256 i2g:(1024,512) o2g:(1024,256)
└── nn.Sequencer( nn.Linear(256,4) )
Torch7 FastLSTM gate order (from inner ParallelTable activations Sigmoid/Tanh/Sigmoid/Sigmoid):
[i, g, f, o] — blocks of size H
PyTorch nn.LSTM gate order:
[i, f, g, o]
Permutation: torch7 blocks [0, 2, 1, 3] → PyTorch.
Bias:
Torch7 FastLSTM has bias on i2g only (o2g = LinearNoBias).
PyTorch nn.LSTM has bias_ih + bias_hh; set bias_hh = 0.
Also includes a numpy float64 FastLSTM reference forward used to verify bit-fidelity.
"""
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torchfile
# Paths derived from this script's location:
#   ROOT -> two directories above this file (parents[1] of its resolved path)
#   REPO -> a "DeepFormants" directory next to ROOT (presumably a checkout of the
#           original DeepFormants project that ships tracking_model.dat — verify)
#   DAT  -> the Torch7 checkpoint to port
#   OUT  -> output directory for the ported artifacts; created at import time
ROOT = Path(__file__).resolve().parents[1]
REPO = ROOT.parent.joinpath("DeepFormants")
DAT = REPO.joinpath("tracking_model.dat")
OUT = ROOT.joinpath("lpc_tracker_torch7")
OUT.mkdir(exist_ok=True, parents=True)
# -------- helpers --------
def tn(o):
    """Return the Torch7 class name of a torchfile object (e.g. "nn.Linear"), or None.

    A ``bytes`` name is decoded to ``str``. Objects that have no ``torch_typename``
    attribute return None. So does any exception raised while calling or decoding it.
    """
    method = getattr(o, "torch_typename", None)
    if method is None:
        return None
    try:
        name = method()
        if isinstance(name, bytes):
            name = name.decode()
    except Exception:
        # Best-effort probe: anything that is not a well-formed Torch object is "unnamed".
        return None
    return name
def get_module(obj, idx):
    """Return the child module at 1-based (Lua-style) position ``idx`` of ``obj``.

    ``obj.modules`` is either a dict keyed by the Lua index, which is looked up
    directly, or a list, which is shifted to 0-based indexing.
    """
    children = obj.modules
    if not isinstance(children, dict):
        return children[idx - 1]
    return children[idx]
def find_fastlstm_weights(fastlstm_obj):
    """Extract the parameters of a Torch7 ``nn.FastLSTM`` as NumPy copies.

    Returns:
        (i2g_W, i2g_b, o2g_W): the input-to-gates weight and bias, and the
        hidden-to-gates weight. o2g is a LinearNoBias, so it has no bias.
        Dtypes are kept as stored; this checkpoint stores float64.
    Raises:
        AssertionError: if i2g is not an ``nn.Linear`` or o2g is not an
        ``nn.LinearNoBias``.
    """
    in_proj, rec_proj = fastlstm_obj.i2g, fastlstm_obj.o2g
    assert tn(in_proj) == "nn.Linear", f"i2g not Linear: {tn(in_proj)}"
    assert tn(rec_proj) == "nn.LinearNoBias", f"o2g not LinearNoBias: {tn(rec_proj)}"
    # Copy so the returned arrays do not alias torchfile's buffers.
    sources = (in_proj.weight, in_proj.bias, rec_proj.weight)
    return tuple(np.array(t, copy=True) for t in sources)
def reorder_blocks(arr, H, perm):
    """Permute the four stacked gate blocks of ``arr`` along axis 0.

    ``arr`` has shape (4H, ...): four consecutive H-row blocks, one per gate.
    Block k of the result is block ``perm[k]`` of ``arr``. Trailing axes are left
    untouched. For a list ``perm`` the result is a new array, because fancy
    indexing copies.
    """
    trailing = arr.shape[1:]
    gate_slabs = arr.reshape((4, H) + trailing)   # (4, H, ...): one slab per gate
    picked = gate_slabs[perm]
    return picked.reshape((4 * H,) + trailing)
# -------- numpy reference FastLSTM forward --------
def sigmoid(x):
    """Element-wise logistic sigmoid ``1 / (1 + exp(-x))``.

    It is evaluated as ``exp(-log(1 + exp(-x)))`` via ``np.logaddexp``. Strongly
    negative inputs (e.g. -1000) therefore return 0.0 cleanly. The naive form first
    overflows ``exp(-x)`` to inf and emits a RuntimeWarning. Works on scalars and arrays.
    """
    return np.exp(-np.logaddexp(0.0, -x))
def fastlstm_forward_np(x_seq, i2g_W, i2g_b, o2g_W, hidden):
    """Float64 NumPy re-implementation of one Torch7 nn.FastLSTM layer.

    Runs a single unbatched sequence starting from zero hidden and cell states.
    At each step the gate pre-activations are
    ``x_t @ i2g_W.T + i2g_b + h_prev @ o2g_W.T``. They are split into four
    ``hidden``-sized blocks in Torch7 order [i, g, f, o].

    Args:
        x_seq: (T, in) input sequence; float64 expected.
        i2g_W, i2g_b: input-to-gates weight (4H, in) and bias (4H,), NOT permuted.
        o2g_W: hidden-to-gates weight (4H, H), NOT permuted (no bias).
        hidden: hidden size H.
    Returns:
        (T, H) float64 array holding the hidden state after every step.
    """
    steps = x_seq.shape[0]
    h_prev = np.zeros(hidden, dtype=np.float64)
    c_prev = np.zeros(hidden, dtype=np.float64)
    states = np.zeros((steps, hidden), dtype=np.float64)
    for step, x_t in enumerate(x_seq):
        pre = x_t @ i2g_W.T + i2g_b + h_prev @ o2g_W.T      # (4H,)
        # Torch7 packing: input gate, cell candidate, forget gate, output gate.
        a_in, a_cand, a_forget, a_out = pre.reshape(4, hidden)
        c_prev = sigmoid(a_forget) * c_prev + sigmoid(a_in) * np.tanh(a_cand)
        h_prev = sigmoid(a_out) * np.tanh(c_prev)
        states[step] = h_prev
    return states
# -------- main port --------
class LSTMTracker(nn.Module):
    """PyTorch re-creation of the Torch7 tracker: two stacked LSTMs plus a linear head.

    Layers: LSTM(350 -> 512) -> LSTM(512 -> 256) -> Linear(256 -> 4), with the head
    applied at every timestep. Inputs are batch-first ``(batch, T, 350)`` and the
    output is ``(batch, T, 4)``. No initial state is passed, so nn.LSTM starts from zeros.
    """

    def __init__(self):
        super().__init__()
        # These attribute names become the state_dict key prefixes ("lstm1.", "lstm2.",
        # "fc.") that the port fills in explicitly, so they must not be renamed.
        self.lstm1 = nn.LSTM(350, 512, batch_first=True)
        self.lstm2 = nn.LSTM(512, 256, batch_first=True)
        self.fc = nn.Linear(256, 4)

    def forward(self, x):
        """Map a ``(batch, T, 350)`` sequence to ``(batch, T, 4)`` predictions."""
        # Index [0] keeps the per-step outputs and drops the final (h_n, c_n) tuple.
        first_layer_seq = self.lstm1(x)[0]
        second_layer_seq = self.lstm2(first_layer_seq)[0]
        return self.fc(second_layer_seq)
def main():
    """Port tracking_model.dat to PyTorch, check it numerically, and write artifacts.

    Steps:
      1. Load the Torch7 checkpoint with torchfile. Locate the two FastLSTM layers
         and the final Linear inside their Sequencer / Recursor wrappers.
      2. Save the raw extracted weights to ``torch7_raw_weights.npz`` for audit.
      3. Permute the gate blocks from Torch7 [i, g, f, o] to PyTorch [i, f, g, o],
         cast to float32, and load them into ``LSTMTracker``.
      4. Compare the PyTorch model with the float64 NumPy FastLSTM reference on a
         random (1, 20, 350) input. Exit with status 2 if max abs diff > 1e-3.
      5. Only after that check passes, write ``reconstructed.pt`` and
         ``PORT_NOTES.md``, so a failed port does not produce a checkpoint.

    Raises:
        AssertionError: if the checkpoint structure or tensor shapes differ from
            the expected architecture.
        SystemExit: with code 2 if the numerical check fails.
    """
    print(f"Loading {DAT} (~3.86 GB) with use_list_heuristic=False ...")
    root = torchfile.load(DAT.as_posix(), use_list_heuristic=False)
    assert tn(root) == "nn.Sequential", f"Expected nn.Sequential, got {tn(root)}"

    # --- Locate the Torch7 modules ---
    seq1 = get_module(root, 1)
    seq2 = get_module(root, 2)
    seq3 = get_module(root, 3)
    assert tn(seq1) == "nn.Sequencer" and tn(seq2) == "nn.Sequencer" and tn(seq3) == "nn.Sequencer"
    fastlstm1 = get_module(seq1, 1)
    fastlstm2 = get_module(seq2, 1)
    # seq3 wraps a Recursor wrapping the Linear
    recursor = get_module(seq3, 1)
    assert tn(recursor) == "nn.Recursor", f"seq3 inner: {tn(recursor)}"
    fc_lua = get_module(recursor, 1)
    assert tn(fastlstm1) == "nn.FastLSTM" and tn(fastlstm2) == "nn.FastLSTM"
    assert tn(fc_lua) == "nn.Linear", f"fc inner: {tn(fc_lua)}"

    # --- Extract raw weights (dtype as stored in the .dat) ---
    i1_W, i1_b, o1_W = find_fastlstm_weights(fastlstm1)
    i2_W, i2_b, o2_W = find_fastlstm_weights(fastlstm2)
    fc_W = np.array(fc_lua.weight, copy=True)
    fc_b = np.array(fc_lua.bias, copy=True)
    print(f" LSTM1 i2g.W {i1_W.shape} o2g.W {o1_W.shape} i2g.b {i1_b.shape}")
    print(f" LSTM2 i2g.W {i2_W.shape} o2g.W {o2_W.shape} i2g.b {i2_b.shape}")
    print(f" FC W {fc_W.shape} b {fc_b.shape}")
    assert i1_W.shape == (2048, 350) and o1_W.shape == (2048, 512) and i1_b.shape == (2048,)
    assert i2_W.shape == (1024, 512) and o2_W.shape == (1024, 256) and i2_b.shape == (1024,)
    assert fc_W.shape == (4, 256) and fc_b.shape == (4,)

    # Parameter count and in-memory size of the raw tensors, reported in PORT_NOTES.md
    # instead of a hard-coded figure.
    raw_arrays = (i1_W, i1_b, o1_W, i2_W, i2_b, o2_W, fc_W, fc_b)
    n_params = sum(a.size for a in raw_arrays)
    raw_mb = sum(a.nbytes for a in raw_arrays) / 1e6

    # Save raw extraction (float64) for audit
    npz_path = OUT / "torch7_raw_weights.npz"
    np.savez(npz_path,
             lstm1_i2g_W=i1_W, lstm1_i2g_b=i1_b, lstm1_o2g_W=o1_W,
             lstm2_i2g_W=i2_W, lstm2_i2g_b=i2_b, lstm2_o2g_W=o2_W,
             fc_W=fc_W, fc_b=fc_b)
    print(f"Raw weights -> {npz_path} ({npz_path.stat().st_size/1e6:.2f} MB)")

    # Apply gate permutation: torch7 [i, g, f, o] -> pytorch [i, f, g, o]
    PERM = [0, 2, 1, 3]
    H1, H2 = 512, 256
    pt_i1_W = reorder_blocks(i1_W, H1, PERM).astype(np.float32)
    pt_i1_b = reorder_blocks(i1_b, H1, PERM).astype(np.float32)
    pt_o1_W = reorder_blocks(o1_W, H1, PERM).astype(np.float32)
    pt_i2_W = reorder_blocks(i2_W, H2, PERM).astype(np.float32)
    pt_i2_b = reorder_blocks(i2_b, H2, PERM).astype(np.float32)
    pt_o2_W = reorder_blocks(o2_W, H2, PERM).astype(np.float32)

    model = LSTMTracker()
    sd = model.state_dict()
    sd["lstm1.weight_ih_l0"] = torch.from_numpy(pt_i1_W)
    sd["lstm1.weight_hh_l0"] = torch.from_numpy(pt_o1_W)
    sd["lstm1.bias_ih_l0"] = torch.from_numpy(pt_i1_b)
    # FastLSTM has no recurrent bias; nn.LSTM sums bias_ih + bias_hh, so zero the latter.
    sd["lstm1.bias_hh_l0"] = torch.zeros(4 * H1, dtype=torch.float32)
    sd["lstm2.weight_ih_l0"] = torch.from_numpy(pt_i2_W)
    sd["lstm2.weight_hh_l0"] = torch.from_numpy(pt_o2_W)
    sd["lstm2.bias_ih_l0"] = torch.from_numpy(pt_i2_b)
    sd["lstm2.bias_hh_l0"] = torch.zeros(4 * H2, dtype=torch.float32)
    sd["fc.weight"] = torch.from_numpy(fc_W.astype(np.float32))
    sd["fc.bias"] = torch.from_numpy(fc_b.astype(np.float32))
    model.load_state_dict(sd)
    model.eval()

    # ----- BIT-FIDELITY TEST -----
    # A tolerance check, not bit equality: float32 PyTorch vs the float64 NumPy reference.
    print("\nBit-fidelity check (numpy float64 FastLSTM ref vs PyTorch nn.LSTM):")
    np.random.seed(0)
    T = 20
    x_np = np.random.randn(T, 350).astype(np.float64)
    # numpy ref: LSTM1 + LSTM2 + FC, using ORIGINAL (non-permuted) Torch7 weights
    h1 = fastlstm_forward_np(x_np, i1_W, i1_b, o1_W, H1)
    h2 = fastlstm_forward_np(h1, i2_W, i2_b, o2_W, H2)
    y_np = h2 @ fc_W.T + fc_b  # (T, 4)
    # PyTorch with permuted weights
    with torch.no_grad():
        x_pt = torch.from_numpy(x_np.astype(np.float32)).unsqueeze(0)  # (1, T, 350)
        y_pt = model(x_pt).numpy().squeeze(0)  # (T, 4)
    # Measure in float64 so the reference is not rounded to float32 before comparing.
    diff = np.abs(y_np - y_pt.astype(np.float64))
    print(f" max abs diff: {diff.max():.3e}")
    print(f" mean abs diff: {diff.mean():.3e}")
    print(f" on Hz scale (×1000): max = {diff.max()*1000:.4f} Hz")
    ok_strict = diff.max() <= 1e-5
    ok_loose = diff.max() <= 1e-3
    print(f" strict (≤1e-5): {'OK' if ok_strict else 'FAIL'}")
    print(f" fallback (≤1e-3): {'OK' if ok_loose else 'FAIL'}")
    if not ok_loose:
        print("\n!! PORT FAILED — bit-fidelity exceeds 1e-3 abs.")
        print(" Most likely cause: wrong gate permutation. Try permutations:")
        print(" [0,1,2,3] [0,2,1,3] [0,3,1,2] [0,1,3,2] [0,3,2,1] [0,2,3,1]")
        print(" reconstructed.pt was NOT written.")
        raise SystemExit(2)

    # Persist the PyTorch weights only after the numerical check has passed.
    pt_path = OUT / "reconstructed.pt"
    torch.save(model.state_dict(), pt_path.as_posix())
    print(f"Reconstructed -> {pt_path} ({pt_path.stat().st_size/1e6:.2f} MB)")

    # Notes. The text contains non-ASCII characters (→, ×, ≤), so the encoding is
    # explicit: the platform default (e.g. cp1252) cannot encode them.
    notes_path = OUT / "PORT_NOTES.md"
    notes_path.write_text(
f"""# tracking_model.dat → PyTorch port notes
## Source
- File: `DeepFormants/tracking_model.dat` ({DAT.stat().st_size/1e9:.2f} GB on disk; {n_params:,} unique parameters = {raw_mb:.1f} MB as {i1_W.dtype})
- Container: `nn.Sequential` of three `nn.Sequencer`
- LSTM1 (FastLSTM, hidden=512): i2g `Linear(350, 2048)`, o2g `LinearNoBias(512, 2048)`
- LSTM2 (FastLSTM, hidden=256): i2g `Linear(512, 1024)`, o2g `LinearNoBias(256, 1024)`
- Final: `Linear(256, 4)` inside Sequencer (applied per timestep)
- Storage dtype: {i1_W.dtype}
## Gate ordering
- Torch7 FastLSTM packs gates `[i, g, f, o]`. Verified via inner ParallelTable activations
in the recurrentModule: Sigmoid / Tanh / Sigmoid / Sigmoid (Tanh marks the cell-candidate `g`).
- PyTorch `nn.LSTM` packs gates `[i, f, g, o]`.
- Permutation applied to `weight`/`bias` 4H-row blocks: `[0, 2, 1, 3]`.
## Bias convention
- FastLSTM has bias on `i2g` only; `o2g` is `LinearNoBias`.
- PyTorch `nn.LSTM` exposes `bias_ih` + `bias_hh` and sums them.
- Port: full Torch7 `i2g.bias` (permuted) → `bias_ih_l0`; `bias_hh_l0 = 0`.
## Bit-fidelity validation
Numpy float64 reference FastLSTM forward (using original non-permuted weights) was compared
to PyTorch `nn.LSTM` forward (using permuted float32 weights) on random `(1, {T}, 350)` input.
- max abs diff (raw output): **{diff.max():.3e}**
- mean abs diff: **{diff.mean():.3e}**
- on ×1000 Hz scale: **{diff.max()*1000:.4f} Hz max**
- strict tolerance (≤1e-5): **{'OK' if ok_strict else 'FAIL'}**; required tolerance (≤1e-3): **OK**
This shows the port matches an independent float64 NumPy re-implementation of the Torch7
FastLSTM forward pass to within the tolerances above. The comparison was not run against
Torch7 itself.
## Conversion
- Read via `torchfile.load(path, use_list_heuristic=False)` to avoid the library's ndarray
equality bug on list-style tables.
- Cast all weights/biases `float64 → float32`.
- Architecture matches the PyTorch retrain `LPC_RNN.pt` exactly — only the weights differ
(these are the original paper weights).
""",
        encoding="utf-8",
    )
    print(f"\nNotes -> {notes_path}")
if __name__ == "__main__":
    # Run the port and its numerical check only when executed as a script,
    # not when this module is imported.
    main()