#!/usr/bin/env python3
"""Export a trained zh-TW/en 8k Inflect-Nano (acoustic + snake_8k vocoder) to ONNX.
Config-driven from the train() checkpoints. FastSpeech split:
encoder.onnx -> numpy host_regulate -> decoder.onnx -> vocoder.onnx.
Validates full-pipeline parity vs torch. Run in moss-train-venv."""
from __future__ import annotations
import argparse, sys, math, json
from pathlib import Path
import numpy as np, torch
REPO = "/tmp/inflect-nano"
sys.path.insert(0, REPO)
from inflect_nano.acoustic import MicroFastSpeech, MicroFastSpeechConfig
from inflect_nano.vocoder import HifiGanGenerator, make_config
class EncoderHead(torch.nn.Module):
    """Export wrapper for the token-level half of the acoustic model.

    Holds a reference to an already-constructed acoustic model ``m`` and
    exposes, as a single ``forward``, the steps that run before the host-side
    length regulation: ``m.encode``, ``m.predict_prosody`` and the energy /
    brightness conditioning.
    """

    def __init__(self, m):
        super().__init__()
        self.m = m

    def forward(self, phone, tone, lang, speaker):
        """Encode the tokens and predict their prosody.

        Returns:
            A ``(cond, dur, pitch)`` tuple:
            ``cond`` is the encoder output plus the projected energy and
            brightness predictions; ``dur`` is an integer tensor of durations,
            each at least 1 and at most 80; ``pitch`` restacks channels 0 and 1
            of the predicted pitch with channel 1 clamped to ``[0, 1]``.
        """
        model = self.m
        # No padding at export time: every position is flagged as a valid token.
        valid = torch.ones_like(phone, dtype=torch.bool)
        hidden = model.encode(phone, tone, lang, speaker, valid)
        log_dur, energy, bright, raw_pitch = model.predict_prosody(hidden, valid)

        # exp(log_dur) - 1, limited to [0, 80], rounded, then floored at 1.
        durations = torch.exp(log_dur) - 1.0
        durations = durations.clamp(0, 80)
        durations = durations.round()
        durations = durations.clamp_min(1)
        durations = durations.long()

        # Scalar prosody values get a trailing axis before being projected.
        conditioned = hidden + model.energy_proj(energy.unsqueeze(-1))
        conditioned = conditioned + model.bright_proj(bright.unsqueeze(-1))

        channel0 = raw_pitch[..., 0]
        channel1 = raw_pitch[..., 1].clamp(0, 1)
        pitch_out = torch.stack([channel0, channel1], dim=-1)
        return conditioned, durations, pitch_out
class DecoderHead(torch.nn.Module):
    """Export wrapper for the frame-level half of the acoustic model.

    Holds a reference to an already-constructed acoustic model ``m`` and runs
    its frame decoder on inputs that were already expanded to frame level
    (the dict produced by ``host_regulate`` supplies arguments with these
    names).
    """

    def __init__(self, m):
        super().__init__()
        self.m = m

    def forward(self, frames, frame_meta, local_ctx_raw, abs_pos, pitch_frame, frame_mask):
        """Decode frame-level features into a mel spectrogram.

        Returns:
            ``mel + cfg.postnet_scale * postnet(mel)``, where ``mel`` is the
            output of ``mel_head`` with axes 1 and 2 swapped.
        """
        model = self.m

        # Sum the frame features with the projected per-frame side inputs.
        hidden = frames + model.frame_proj(frame_meta)
        hidden = hidden + model.local_ctx(local_ctx_raw)
        hidden = hidden + model.abs_frame(abs_pos)

        # Optional pitch branch, switched by the model config.
        if model.cfg.use_frame_pitch:
            refined_pitch = model.refine_frame_pitch(frames, frame_meta, pitch_frame)
            hidden = hidden + model.pitch_proj(refined_pitch)

        for block in model.decoder:
            hidden = block(hidden, frame_mask)

        # Residual connection around the GRU; element 0 of its result is used.
        gru_out = model.frame_gru(hidden)[0]
        hidden = hidden + gru_out

        mel = model.mel_head(hidden).transpose(1, 2)
        correction = model.postnet(mel)
        return mel + model.cfg.postnet_scale * correction
def host_regulate(cond, dur, pitch, abs_bins, max_frames):
    """Expand token-level encoder outputs to frame level with NumPy.

    Only batch item 0 of ``cond``, ``dur`` and ``pitch`` is read. Every token
    row is repeated ``dur`` times along the time axis; negative durations are
    treated as zero. The caller's ``dur`` array is not modified.

    Args:
        cond: array indexed ``[batch, token, feature]``.
        dur: array indexed ``[batch, token]`` holding repeat counts.
        pitch: array indexed ``[batch, token, ...]``.
        abs_bins: number of absolute-position bins.
        max_frames: frame count that the bins are spread over.

    Returns:
        Dict with keys ``frames``, ``frame_meta``, ``local_ctx_raw``,
        ``abs_pos``, ``pitch_frame`` and ``frame_mask``; every value has a
        leading axis of length 1. ``abs_pos`` is int64, ``frame_mask`` is an
        all-True bool array, the rest are float32.
    """
    token_feats = cond[0]
    reps = dur[0].astype(np.int64)  # astype copies, so the input stays intact
    reps[reps < 0] = 0
    n_tokens, _feat_dim = token_feats.shape

    expanded = np.repeat(token_feats, reps, axis=0)
    n_frames = expanded.shape[0]

    # For each frame: which token it came from and where that token starts.
    owner = np.repeat(np.arange(n_tokens), reps)
    first_frame = np.cumsum(reps) - reps
    offset = np.arange(n_frames) - first_frame[owner]
    span = reps[owner].astype(np.float32)

    # Position inside the owning token, 0 at its first frame, 1 at its last.
    rel = (offset / np.maximum(span - 1, 1)).astype(np.float32)
    live_tokens = max(1, int((reps > 0).sum()))
    token_pos = (owner / max(1, live_tokens - 1)).astype(np.float32)
    log_span = (np.log1p(span) / 6.0).astype(np.float32)
    center = 1.0 - np.abs(rel * 2 - 1)

    meta_columns = [
        rel,
        1 - rel,
        center,
        np.sin(rel * np.pi),
        np.cos(rel * np.pi),
        token_pos,
        log_span,
        span / 40.0,
    ]
    frame_meta = np.stack(meta_columns, -1).astype(np.float32)

    # Previous / current / next token features side by side; the first and
    # last tokens reuse themselves as the missing neighbour.
    left = np.concatenate([token_feats[:1], token_feats[:-1]], 0)
    right = np.concatenate([token_feats[1:], token_feats[-1:]], 0)
    neighbourhood = np.concatenate([left, token_feats, right], -1)
    local_ctx = np.repeat(neighbourhood, reps, axis=0).astype(np.float32)

    frame_idx = np.arange(n_frames)
    raw_bins = frame_idx * abs_bins // max(1, max_frames)
    abs_pos = np.minimum(raw_bins, abs_bins - 1).astype(np.int64)

    pitch_frames = np.repeat(pitch[0], reps, axis=0).astype(np.float32)

    return {
        "frames": expanded[None].astype(np.float32),
        "frame_meta": frame_meta[None],
        "local_ctx_raw": local_ctx[None],
        "abs_pos": abs_pos[None],
        "pitch_frame": pitch_frames[None],
        "frame_mask": np.ones((1, n_frames), bool),
    }
def main():
    """Export the acoustic encoder/decoder and the vocoder to ONNX and report parity.

    Command-line arguments:
        --acoustic-ckpt: acoustic checkpoint; its ``"config"`` and ``"model"``
            entries are read.
        --vocoder-ckpt: vocoder checkpoint; its ``"config"]["variant"`` and
            ``"generator"`` entries are read.
        --out-dir: output directory, created if it does not exist.
        --symbol-table: accepted on the command line but not read by this
            function.

    Writes ``acoustic_encoder.onnx``, ``acoustic_decoder.onnx``,
    ``vocoder.onnx`` and ``meta.json`` into the output directory, and prints
    the torch split-vs-``infer`` mel difference and the ONNX-vs-torch waveform
    difference. The differences are reported only; no threshold is enforced.

    Raises:
        ValueError: if the vocoder and acoustic sample rates differ.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--acoustic-ckpt", required=True)
    ap.add_argument("--vocoder-ckpt", required=True)
    ap.add_argument("--out-dir", required=True)
    # NOTE(review): parsed but never used below; kept so existing invocations keep working.
    ap.add_argument("--symbol-table", default="/home/luigi/jetson-tts/mossnano/zhtw8k/symbol_table.json")
    args = ap.parse_args()

    # Function-scope import: onnxruntime is only needed once the script actually runs.
    import onnxruntime as ort

    OUT = Path(args.out_dir)
    OUT.mkdir(parents=True, exist_ok=True)
    dev = torch.device("cpu")

    # SECURITY: weights_only=False unpickles arbitrary Python objects.
    # Only point this script at checkpoints from a trusted source.
    ac = torch.load(args.acoustic_ckpt, map_location=dev, weights_only=False)
    cfg = MicroFastSpeechConfig(**ac["config"])
    m = MicroFastSpeech(cfg)
    # strict=False tolerates key mismatches; surface them instead of exporting
    # a partially initialised model without any hint.
    load_report = m.load_state_dict(ac["model"], strict=False)
    if load_report.missing_keys or load_report.unexpected_keys:
        print(f"warning: non-strict acoustic load: missing={list(load_report.missing_keys)} "
              f"unexpected={list(load_report.unexpected_keys)}")
    m.eval()

    # The group-duration planner uses a non-ONNX-able host loop and only adjusts inference-time
    # durations (the mel decoder is trained on GT durations). Disable it so the exported encoder's
    # plain-duration path (which keeps the contextual duration-delta) matches m.infer() for parity.
    if getattr(m, "group_duration_delta", None) is not None:
        m.group_duration_delta = None
        print("note: group_duration_planner disabled at export (host-loop; plain durations used)")

    enc, dec = EncoderHead(m).eval(), DecoderHead(m).eval()
    print(f"acoustic: sr={cfg.sample_rate} vocab={cfg.vocab_size} tone={cfg.tone_size} lang={cfg.lang_size} "
          f"abs_bins={cfg.abs_frame_bins} max_frames={cfg.max_frames}")

    # SECURITY: same pickle caveat as for the acoustic checkpoint above.
    vc = torch.load(args.vocoder_ckpt, map_location=dev, weights_only=False)
    vcfg = make_config(vc["config"]["variant"])
    vm = HifiGanGenerator(vcfg)
    vm.load_state_dict(vc["generator"])
    vm.remove_weight_norm()
    vm.eval()
    # Explicit raise instead of assert so the check survives `python -O`.
    if vcfg.sample_rate != cfg.sample_rate:
        raise ValueError(
            f"sample-rate mismatch: vocoder={vcfg.sample_rate} acoustic={cfg.sample_rate}")

    # Decoder input names, in the positional order of DecoderHead.forward.
    bn = ["frames", "frame_meta", "local_ctx_raw", "abs_pos", "pitch_frame", "frame_mask"]

    # sample input: a short valid id sequence (plumbing test)
    T = 40
    g = torch.Generator().manual_seed(0)
    phone = torch.randint(1, min(80, cfg.vocab_size), (1, T), generator=g)
    tone = torch.randint(0, cfg.tone_size, (1, T), generator=g)
    lang = torch.randint(0, cfg.lang_size, (1, T), generator=g)
    spk = torch.zeros(1, dtype=torch.long)

    # Torch-side parity: split pipeline (encoder -> host_regulate -> decoder) vs m.infer().
    with torch.no_grad():
        cond, dur, pitch = enc(phone, tone, lang, spk)
        reg = host_regulate(cond.numpy(), dur.numpy(), pitch.numpy(), cfg.abs_frame_bins, cfg.max_frames)
        bt = tuple(torch.from_numpy(reg[k]).clone() for k in bn)
        mel_split = dec(*bt)
        mel_ref = m.infer(phone, tone, lang, spk)
    print(f"mel parity max_abs_diff={float((mel_ref-mel_split).abs().max()):.2e}")

    # dynamo=False selects the TorchScript-based exporter.
    torch.onnx.export(
        enc, (phone, tone, lang, spk), str(OUT / "acoustic_encoder.onnx"),
        input_names=["phone", "tone", "lang", "speaker"],
        output_names=["conditioned", "durations", "pitch"],
        dynamic_axes={"phone": {1: "T"}, "tone": {1: "T"}, "lang": {1: "T"},
                      "conditioned": {1: "T"}, "durations": {1: "T"}, "pitch": {1: "T"}},
        opset_version=17, dynamo=False)
    torch.onnx.export(
        dec, bt, str(OUT / "acoustic_decoder.onnx"),
        input_names=bn, output_names=["mel"],
        dynamic_axes={**{n: {1: "F"} for n in bn}, "mel": {2: "F"}},
        opset_version=17, dynamo=False)
    dummy = torch.randn(1, vcfg.num_mels, 60)  # match vocoder mel count (40 for snake_8k40, 80 otherwise)
    torch.onnx.export(
        vm, dummy, str(OUT / "vocoder.onnx"),
        input_names=["mel"], output_names=["wav"],
        dynamic_axes={"mel": {2: "frames"}, "wav": {2: "samples"}},
        opset_version=17, dynamo=False)

    # full-pipeline ONNX parity
    sA = ort.InferenceSession(str(OUT / "acoustic_encoder.onnx"), providers=["CPUExecutionProvider"])
    sB = ort.InferenceSession(str(OUT / "acoustic_decoder.onnx"), providers=["CPUExecutionProvider"])
    sV = ort.InferenceSession(str(OUT / "vocoder.onnx"), providers=["CPUExecutionProvider"])
    oc, od, op = sA.run(None, {"phone": phone.numpy(), "tone": tone.numpy(),
                               "lang": lang.numpy(), "speaker": spk.numpy()})
    reg2 = host_regulate(oc, od, op, cfg.abs_frame_bins, cfg.max_frames)
    # Float inputs go in as float32, the mask stays bool, abs_pos is fed as int64.
    feeds = {}
    for name in bn:
        arr = reg2[name]
        feeds[name] = arr if arr.dtype == bool else arr.astype(np.float32)
    feeds["abs_pos"] = reg2["abs_pos"].astype(np.int64)
    mel_onnx = sB.run(None, feeds)[0]
    wav_onnx = sV.run(None, {"mel": mel_onnx.astype(np.float32)})[0]
    with torch.inference_mode():
        wav_ref = vm(mel_ref).numpy()
    # Compare only the overlapping samples in case the two lengths differ.
    n = min(wav_ref.shape[-1], wav_onnx.shape[-1])
    print(f"FULL-PIPELINE wav parity max_abs_diff={float(np.abs(wav_ref[...,:n]-wav_onnx[...,:n]).max()):.2e}")

    # save metadata for the Nano runtime
    meta = {"sample_rate": cfg.sample_rate, "abs_frame_bins": cfg.abs_frame_bins, "max_frames": cfg.max_frames,
            "hop_size": vcfg.hop_size, "n_mels": cfg.n_mels, "use_frame_pitch": cfg.use_frame_pitch}
    # Context manager so the handle is flushed and closed before reporting success.
    with open(OUT / "meta.json", "w", encoding="utf-8") as fh:
        json.dump(meta, fh, indent=1)

    print("sizes(KB):", {f.name: f.stat().st_size//1024 for f in OUT.glob("*.onnx")})
    print("EXPORT_OK", OUT)
# Script entry point: run the export only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|