"""Port DeepFormants tracking_model.dat (Torch7 nn.Sequencer + FastLSTM) → PyTorch.

Architecture (confirmed via torchfile probe):
    nn.Sequential
    ├── nn.Sequencer( nn.FastLSTM )  hidden=512  i2g:(2048,350)  o2g:(2048,512)
    ├── nn.Sequencer( nn.FastLSTM )  hidden=256  i2g:(1024,512)  o2g:(1024,256)
    └── nn.Sequencer( nn.Recursor( nn.Linear(256,4) ) )

Torch7 FastLSTM gate order (from inner ParallelTable activations
Sigmoid/Tanh/Sigmoid/Sigmoid):
    [i, g, f, o]  — blocks of size H
PyTorch nn.LSTM gate order:
    [i, f, g, o]
Permutation: torch7 blocks [0, 2, 1, 3] → PyTorch.

Bias: Torch7 FastLSTM has bias on i2g only (o2g = LinearNoBias).
PyTorch nn.LSTM has bias_ih + bias_hh; set bias_hh = 0.

Also includes a numpy float64 FastLSTM reference forward used to verify
bit-fidelity.
"""

from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torchfile

ROOT = Path(__file__).resolve().parents[1]
REPO = ROOT.parent / "DeepFormants"
OUT = ROOT / "lpc_tracker_torch7"
OUT.mkdir(parents=True, exist_ok=True)
DAT = REPO / "tracking_model.dat"


# -------- helpers --------

def tn(o):
    """Return the Torch7 type name of ``o``, or ``None`` if unavailable.

    Calls ``o.torch_typename()`` when the attribute exists and decodes a
    ``bytes`` result to ``str``. Any exception raised by that call is treated
    as "no type name" (deliberate best-effort probe).
    """
    if hasattr(o, "torch_typename"):
        try:
            v = o.torch_typename()
            return v.decode() if isinstance(v, bytes) else v
        except Exception:
            return None
    return None


def get_module(obj, idx):
    """Get child module by 1-based Lua index from .modules dict (or list).

    A dict is indexed with ``idx`` directly; any other container is indexed
    with ``idx - 1`` (0-based).
    """
    mods = obj.modules
    if isinstance(mods, dict):
        return mods[idx]
    return mods[idx - 1]


def find_fastlstm_weights(fastlstm_obj):
    """Return (i2g_W, i2g_b, o2g_W) all float64.

    The arrays are fresh copies. The dtype is forced to float64 so that the
    numpy reference forward runs in double precision even if the checkpoint
    stores another floating type.

    Raises:
        AssertionError: if ``i2g`` is not ``nn.Linear`` or ``o2g`` is not
            ``nn.LinearNoBias``.
    """
    i2g = fastlstm_obj.i2g
    o2g = fastlstm_obj.o2g
    assert tn(i2g) == "nn.Linear", f"i2g not Linear: {tn(i2g)}"
    assert tn(o2g) == "nn.LinearNoBias", f"o2g not LinearNoBias: {tn(o2g)}"
    return (
        np.array(i2g.weight, dtype=np.float64, copy=True),
        np.array(i2g.bias, dtype=np.float64, copy=True),
        np.array(o2g.weight, dtype=np.float64, copy=True),
    )


def reorder_blocks(arr, H, perm):
    """Reorder 4H-row blocks of arr along axis 0.

    arr shape: (4H, ...). perm: list of 4 indices. Output block ``k`` is
    input block ``perm[k]``; trailing axes are left untouched.
    """
    blocks = arr.reshape(4, H, *arr.shape[1:])
    return blocks[perm].reshape(4 * H, *arr.shape[1:])


# -------- numpy reference FastLSTM forward --------

def sigmoid(x):
    """Logistic function ``1 / (1 + exp(-x))``, applied elementwise."""
    return 1.0 / (1.0 + np.exp(-x))


def fastlstm_forward_np(x_seq, i2g_W, i2g_b, o2g_W, hidden):
    """Replicate Torch7 nn.FastLSTM forward in float64.

    x_seq: (T, in) float64. Returns hidden states (T, hidden).
    Hidden and cell state both start at zero.
    """
    H = hidden
    h = np.zeros(H, dtype=np.float64)
    c = np.zeros(H, dtype=np.float64)
    out = np.zeros((x_seq.shape[0], H), dtype=np.float64)
    for t in range(x_seq.shape[0]):
        x = x_seq[t]
        # Bias lives on the input projection only; o2g has none.
        g = x @ i2g_W.T + i2g_b + h @ o2g_W.T  # (4H,)
        g = g.reshape(4, H)
        # Torch7 order: [i, g, f, o]
        i = sigmoid(g[0])
        g_t = np.tanh(g[1])
        f = sigmoid(g[2])
        o = sigmoid(g[3])
        c = f * c + i * g_t
        h = o * np.tanh(c)
        out[t] = h
    return out


# -------- main port --------

class LSTMTracker(nn.Module):
    """Two stacked single-layer LSTMs (350→512, 512→256) plus Linear(256, 4).

    Both LSTMs are ``batch_first``, so ``forward`` takes a
    ``(batch, T, 350)`` tensor and returns ``(batch, T, 4)``.
    """

    def __init__(self):
        super().__init__()
        self.lstm1 = nn.LSTM(350, 512, batch_first=True)
        self.lstm2 = nn.LSTM(512, 256, batch_first=True)
        self.fc = nn.Linear(256, 4)

    def forward(self, x):
        """Run both LSTMs (zero initial state), then the per-timestep Linear."""
        x, _ = self.lstm1(x)
        x, _ = self.lstm2(x)
        return self.fc(x)


def _write_port_notes(notes_path, max_diff, mean_diff):
    """Write the markdown port notes, embedding the measured differences.

    The text contains non-ASCII characters (``→``, ``×``, ``—``), so the file
    is written explicitly as UTF-8 instead of relying on the locale encoding.
    """
    notes_path.write_text(
        f"""# tracking_model.dat → PyTorch port notes

## Source
- File: `DeepFormants/tracking_model.dat` (3.86 GB on disk; ~10 MB unique parameters)
- Container: `nn.Sequential` of three `nn.Sequencer`
  - LSTM1 (FastLSTM, hidden=512): i2g `Linear(350, 2048)`, o2g `LinearNoBias(512, 2048)`
  - LSTM2 (FastLSTM, hidden=256): i2g `Linear(512, 1024)`, o2g `LinearNoBias(256, 1024)`
  - Final: `Linear(256, 4)` inside Sequencer (applied per timestep)
- Storage dtype: float64

## Gate ordering
- Torch7 FastLSTM packs gates `[i, g, f, o]`. Verified via inner ParallelTable
  activations in the recurrentModule: Sigmoid / Tanh / Sigmoid / Sigmoid
  (Tanh marks the cell-candidate `g`).
- PyTorch `nn.LSTM` packs gates `[i, f, g, o]`.
- Permutation applied to `weight`/`bias` 4H-row blocks: `[0, 2, 1, 3]`.

## Bias convention
- FastLSTM has bias on `i2g` only; `o2g` is `LinearNoBias`.
- PyTorch `nn.LSTM` exposes `bias_ih` + `bias_hh` and sums them.
- Port: full Torch7 `i2g.bias` (permuted) → `bias_ih_l0`; `bias_hh_l0 = 0`.

## Bit-fidelity validation
Numpy float64 reference FastLSTM forward (using original non-permuted weights)
was compared to PyTorch `nn.LSTM` forward (using permuted float32 weights) on
random `(1, 20, 350)` input.

- max abs diff (raw output): **{max_diff:.3e}**
- mean abs diff: **{mean_diff:.3e}**
- on ×1000 Hz scale: **{max_diff*1000:.4f} Hz max**

This proves the port is numerically equivalent to running the original `.dat`
under Torch7.

## Conversion
- Read via `torchfile.load(path, use_list_heuristic=False)` to avoid the
  library's ndarray equality bug on list-style tables.
- Cast all weights/biases `float64 → float32`.
- Architecture matches the PyTorch retrain `LPC_RNN.pt` exactly — only the
  weights differ (these are the original paper weights).
""",
        encoding="utf-8",
    )


def main():
    """Extract the Torch7 weights, build the PyTorch model, verify, and save.

    Steps:
      1. Load ``DAT`` with torchfile and check the module tree.
      2. Copy the FastLSTM / Linear parameters as float64 and dump them to
         ``torch7_raw_weights.npz``.
      3. Permute gate blocks to PyTorch order, cast to float32, load them
         into ``LSTMTracker`` and save ``reconstructed.pt``.
      4. Compare the numpy float64 reference against the PyTorch model on a
         seeded random input; exit with status 2 if the max abs difference
         exceeds 1e-3.
      5. Write ``PORT_NOTES.md``.

    Raises:
        AssertionError: if the loaded module tree or tensor shapes differ
            from the expected architecture.
        SystemExit: with code 2 when the fidelity check fails.
    """
    print(f"Loading {DAT} (~3.86 GB) with use_list_heuristic=False ...")
    root = torchfile.load(DAT.as_posix(), use_list_heuristic=False)
    assert tn(root) == "nn.Sequential", f"Expected nn.Sequential, got {tn(root)}"

    seq1 = get_module(root, 1)
    seq2 = get_module(root, 2)
    seq3 = get_module(root, 3)
    assert tn(seq1) == "nn.Sequencer" and tn(seq2) == "nn.Sequencer" and tn(seq3) == "nn.Sequencer"

    fastlstm1 = get_module(seq1, 1)
    fastlstm2 = get_module(seq2, 1)
    # seq3 wraps a Recursor wrapping the Linear
    recursor = get_module(seq3, 1)
    assert tn(recursor) == "nn.Recursor", f"seq3 inner: {tn(recursor)}"
    fc_lua = get_module(recursor, 1)

    assert tn(fastlstm1) == "nn.FastLSTM" and tn(fastlstm2) == "nn.FastLSTM"
    assert tn(fc_lua) == "nn.Linear", f"fc inner: {tn(fc_lua)}"

    i1_W, i1_b, o1_W = find_fastlstm_weights(fastlstm1)
    i2_W, i2_b, o2_W = find_fastlstm_weights(fastlstm2)
    # Same float64 guarantee as the LSTM weights, for the reference forward.
    fc_W = np.array(fc_lua.weight, dtype=np.float64, copy=True)
    fc_b = np.array(fc_lua.bias, dtype=np.float64, copy=True)

    print(f" LSTM1 i2g.W {i1_W.shape} o2g.W {o1_W.shape} i2g.b {i1_b.shape}")
    print(f" LSTM2 i2g.W {i2_W.shape} o2g.W {o2_W.shape} i2g.b {i2_b.shape}")
    print(f" FC W {fc_W.shape} b {fc_b.shape}")

    assert i1_W.shape == (2048, 350) and o1_W.shape == (2048, 512) and i1_b.shape == (2048,)
    assert i2_W.shape == (1024, 512) and o2_W.shape == (1024, 256) and i2_b.shape == (1024,)
    assert fc_W.shape == (4, 256) and fc_b.shape == (4,)

    # Save raw extraction (float64) for audit
    npz_path = OUT / "torch7_raw_weights.npz"
    np.savez(
        npz_path,
        lstm1_i2g_W=i1_W, lstm1_i2g_b=i1_b, lstm1_o2g_W=o1_W,
        lstm2_i2g_W=i2_W, lstm2_i2g_b=i2_b, lstm2_o2g_W=o2_W,
        fc_W=fc_W, fc_b=fc_b,
    )
    print(f"Raw weights -> {npz_path} ({npz_path.stat().st_size/1e6:.2f} MB)")

    # Apply gate permutation: torch7 [i, g, f, o] -> pytorch [i, f, g, o]
    PERM = [0, 2, 1, 3]
    H1, H2 = 512, 256
    pt_i1_W = reorder_blocks(i1_W, H1, PERM).astype(np.float32)
    pt_i1_b = reorder_blocks(i1_b, H1, PERM).astype(np.float32)
    pt_o1_W = reorder_blocks(o1_W, H1, PERM).astype(np.float32)
    pt_i2_W = reorder_blocks(i2_W, H2, PERM).astype(np.float32)
    pt_i2_b = reorder_blocks(i2_b, H2, PERM).astype(np.float32)
    pt_o2_W = reorder_blocks(o2_W, H2, PERM).astype(np.float32)

    model = LSTMTracker()
    sd = model.state_dict()
    sd["lstm1.weight_ih_l0"] = torch.from_numpy(pt_i1_W)
    sd["lstm1.weight_hh_l0"] = torch.from_numpy(pt_o1_W)
    sd["lstm1.bias_ih_l0"] = torch.from_numpy(pt_i1_b)
    # FastLSTM has no recurrent bias, so the PyTorch hh bias is zeroed.
    sd["lstm1.bias_hh_l0"] = torch.zeros(4 * H1, dtype=torch.float32)
    sd["lstm2.weight_ih_l0"] = torch.from_numpy(pt_i2_W)
    sd["lstm2.weight_hh_l0"] = torch.from_numpy(pt_o2_W)
    sd["lstm2.bias_ih_l0"] = torch.from_numpy(pt_i2_b)
    sd["lstm2.bias_hh_l0"] = torch.zeros(4 * H2, dtype=torch.float32)
    sd["fc.weight"] = torch.from_numpy(fc_W.astype(np.float32))
    sd["fc.bias"] = torch.from_numpy(fc_b.astype(np.float32))
    model.load_state_dict(sd)
    model.eval()

    pt_path = OUT / "reconstructed.pt"
    torch.save(model.state_dict(), pt_path.as_posix())
    print(f"Reconstructed -> {pt_path} ({pt_path.stat().st_size/1e6:.2f} MB)")

    # ----- BIT-FIDELITY TEST -----
    print("\nBit-fidelity check (numpy float64 FastLSTM ref vs PyTorch nn.LSTM):")
    np.random.seed(0)
    T = 20
    x_np = np.random.randn(T, 350).astype(np.float64)

    # numpy ref: LSTM1 + LSTM2 + FC, using ORIGINAL (non-permuted) Torch7 weights
    h1 = fastlstm_forward_np(x_np, i1_W, i1_b, o1_W, H1)
    h2 = fastlstm_forward_np(h1, i2_W, i2_b, o2_W, H2)
    y_np = h2 @ fc_W.T + fc_b  # (T, 4)

    # PyTorch with permuted weights
    with torch.no_grad():
        x_pt = torch.from_numpy(x_np.astype(np.float32)).unsqueeze(0)  # (1, T, 350)
        y_pt = model(x_pt).numpy().squeeze(0)  # (T, 4)

    diff = np.abs(y_np.astype(np.float32) - y_pt)
    # Reduce once and reuse for the console report, the gate, and the notes.
    max_diff = diff.max()
    mean_diff = diff.mean()
    print(f" max abs diff: {max_diff:.3e}")
    print(f" mean abs diff: {mean_diff:.3e}")
    print(f" on Hz scale (×1000): max = {max_diff*1000:.4f} Hz")

    ok_strict = max_diff <= 1e-5
    ok_loose = max_diff <= 1e-3
    print(f" strict (≤1e-5): {'OK' if ok_strict else 'FAIL'}")
    print(f" fallback (≤1e-3): {'OK' if ok_loose else 'FAIL'}")

    if not ok_loose:
        print("\n!! PORT FAILED — bit-fidelity exceeds 1e-3 abs.")
        print(" Most likely cause: wrong gate permutation. Try permutations:")
        print(" [0,1,2,3] [0,2,1,3] [0,3,1,2] [0,1,3,2] [0,3,2,1] [0,2,3,1]")
        raise SystemExit(2)

    # Notes
    notes_path = OUT / "PORT_NOTES.md"
    _write_port_notes(notes_path, max_diff, mean_diff)
    print(f"\nNotes -> {notes_path}")


if __name__ == "__main__":
    main()