#!/usr/bin/env python3 """Export a trained zh-TW/en 8k Inflect-Nano (acoustic + snake_8k vocoder) to ONNX. Config-driven from the train() checkpoints. FastSpeech split: encoder.onnx -> numpy host_regulate -> decoder.onnx -> vocoder.onnx. Validates full-pipeline parity vs torch. Run in moss-train-venv.""" from __future__ import annotations import argparse, sys, math, json from pathlib import Path import numpy as np, torch REPO = "/tmp/inflect-nano" sys.path.insert(0, REPO) from inflect_nano.acoustic import MicroFastSpeech, MicroFastSpeechConfig from inflect_nano.vocoder import HifiGanGenerator, make_config class EncoderHead(torch.nn.Module): def __init__(self, m): super().__init__(); self.m = m def forward(self, phone, tone, lang, speaker): m = self.m tok = torch.ones_like(phone, dtype=torch.bool) enc = m.encode(phone, tone, lang, speaker, tok) log_dur, energy, bright, pitch = m.predict_prosody(enc, tok) dur = (torch.exp(log_dur) - 1.0).clamp(0, 80).round().clamp_min(1).long() cond = enc + m.energy_proj(energy.unsqueeze(-1)) + m.bright_proj(bright.unsqueeze(-1)) pitch = torch.stack([pitch[..., 0], pitch[..., 1].clamp(0, 1)], dim=-1) return cond, dur, pitch class DecoderHead(torch.nn.Module): def __init__(self, m): super().__init__(); self.m = m def forward(self, frames, frame_meta, local_ctx_raw, abs_pos, pitch_frame, frame_mask): m = self.m x = frames + m.frame_proj(frame_meta) + m.local_ctx(local_ctx_raw) x = x + m.abs_frame(abs_pos) if m.cfg.use_frame_pitch: x = x + m.pitch_proj(pitch_frame) for blk in m.decoder: x = blk(x, frame_mask) x = x + m.frame_gru(x)[0] mel = m.mel_head(x).transpose(1, 2) return mel + m.cfg.postnet_scale * m.postnet(mel) def host_regulate(cond, dur, pitch, abs_bins, max_frames): c = cond[0]; d = dur[0].astype(np.int64); d[d < 0] = 0 T, H = c.shape frames = np.repeat(c, d, axis=0); F = frames.shape[0] tok = np.repeat(np.arange(T), d); starts = np.cumsum(d) - d within = np.arange(F) - starts[tok]; dpf = d[tok].astype(np.float32) rel = (within / np.maximum(dpf - 1, 1)).astype(np.float32) tc = max(1, int((d > 0).sum())); token_pos = (tok / max(1, tc - 1)).astype(np.float32) ld = (np.log1p(dpf) / 6.0).astype(np.float32); center = 1.0 - np.abs(rel * 2 - 1) fm = np.stack([rel, 1 - rel, center, np.sin(rel*np.pi), np.cos(rel*np.pi), token_pos, ld, dpf/40.0], -1).astype(np.float32) prev = np.concatenate([c[:1], c[:-1]], 0); nxt = np.concatenate([c[1:], c[-1:]], 0) lc = np.repeat(np.concatenate([prev, c, nxt], -1), d, axis=0).astype(np.float32) pos = np.arange(F); abs_pos = np.minimum(pos*abs_bins//max(1, max_frames), abs_bins-1).astype(np.int64) pf = np.repeat(pitch[0], d, axis=0).astype(np.float32) return {"frames": frames[None].astype(np.float32), "frame_meta": fm[None], "local_ctx_raw": lc[None], "abs_pos": abs_pos[None], "pitch_frame": pf[None], "frame_mask": np.ones((1, F), bool)} def main(): ap = argparse.ArgumentParser() ap.add_argument("--acoustic-ckpt", required=True) ap.add_argument("--vocoder-ckpt", required=True) ap.add_argument("--out-dir", required=True) ap.add_argument("--symbol-table", default="/home/luigi/jetson-tts/mossnano/zhtw8k/symbol_table.json") args = ap.parse_args() import onnxruntime as ort OUT = Path(args.out_dir); OUT.mkdir(parents=True, exist_ok=True) dev = torch.device("cpu") ac = torch.load(args.acoustic_ckpt, map_location=dev, weights_only=False) cfg = MicroFastSpeechConfig(**ac["config"]) m = MicroFastSpeech(cfg); m.load_state_dict(ac["model"]); m.eval() enc, dec = EncoderHead(m).eval(), DecoderHead(m).eval() print(f"acoustic: sr={cfg.sample_rate} vocab={cfg.vocab_size} tone={cfg.tone_size} lang={cfg.lang_size} " f"abs_bins={cfg.abs_frame_bins} max_frames={cfg.max_frames}") vc = torch.load(args.vocoder_ckpt, map_location=dev, weights_only=False) vcfg = make_config(vc["config"]["variant"]) vm = HifiGanGenerator(vcfg); vm.load_state_dict(vc["generator"]); vm.remove_weight_norm(); vm.eval() assert vcfg.sample_rate == cfg.sample_rate # sample input: a short valid id sequence (plumbing test) T = 40 g = torch.Generator().manual_seed(0) phone = torch.randint(1, min(80, cfg.vocab_size), (1, T), generator=g) tone = torch.randint(0, cfg.tone_size, (1, T), generator=g) lang = torch.randint(0, cfg.lang_size, (1, T), generator=g) spk = torch.zeros(1, dtype=torch.long) with torch.no_grad(): cond, dur, pitch = enc(phone, tone, lang, spk) reg = host_regulate(cond.numpy(), dur.numpy(), pitch.numpy(), cfg.abs_frame_bins, cfg.max_frames) bt = tuple(torch.from_numpy(reg[k]).clone() for k in ["frames","frame_meta","local_ctx_raw","abs_pos","pitch_frame","frame_mask"]) mel_split = dec(*bt) mel_ref = m.infer(phone, tone, lang, spk) print(f"mel parity max_abs_diff={float((mel_ref-mel_split).abs().max()):.2e}") torch.onnx.export(enc, (phone, tone, lang, spk), str(OUT/"acoustic_encoder.onnx"), input_names=["phone","tone","lang","speaker"], output_names=["conditioned","durations","pitch"], dynamic_axes={"phone":{1:"T"},"tone":{1:"T"},"lang":{1:"T"},"conditioned":{1:"T"},"durations":{1:"T"},"pitch":{1:"T"}}, opset_version=17, dynamo=False) bn = ["frames","frame_meta","local_ctx_raw","abs_pos","pitch_frame","frame_mask"] torch.onnx.export(dec, bt, str(OUT/"acoustic_decoder.onnx"), input_names=bn, output_names=["mel"], dynamic_axes={**{n:{1:"F"} for n in bn}, "mel":{2:"F"}}, opset_version=17, dynamo=False) dummy = torch.randn(1, vcfg.num_mels, 60) # match vocoder mel count (40 for snake_8k40, 80 otherwise) torch.onnx.export(vm, dummy, str(OUT/"vocoder.onnx"), input_names=["mel"], output_names=["wav"], dynamic_axes={"mel":{2:"frames"},"wav":{2:"samples"}}, opset_version=17, dynamo=False) # full-pipeline ONNX parity sA = ort.InferenceSession(str(OUT/"acoustic_encoder.onnx"), providers=["CPUExecutionProvider"]) sB = ort.InferenceSession(str(OUT/"acoustic_decoder.onnx"), providers=["CPUExecutionProvider"]) sV = ort.InferenceSession(str(OUT/"vocoder.onnx"), providers=["CPUExecutionProvider"]) oc, od, op = sA.run(None, {"phone":phone.numpy(),"tone":tone.numpy(),"lang":lang.numpy(),"speaker":spk.numpy()}) reg2 = host_regulate(oc, od, op, cfg.abs_frame_bins, cfg.max_frames) feeds = {n:(reg2[n].astype(np.float32) if reg2[n].dtype!=bool else reg2[n]) for n in bn} feeds["abs_pos"] = reg2["abs_pos"].astype(np.int64) mel_onnx = sB.run(None, feeds)[0] wav_onnx = sV.run(None, {"mel": mel_onnx.astype(np.float32)})[0] with torch.inference_mode(): wav_ref = vm(mel_ref).numpy() n = min(wav_ref.shape[-1], wav_onnx.shape[-1]) print(f"FULL-PIPELINE wav parity max_abs_diff={float(np.abs(wav_ref[...,:n]-wav_onnx[...,:n]).max()):.2e}") # save metadata for the Nano runtime json.dump({"sample_rate":cfg.sample_rate,"abs_frame_bins":cfg.abs_frame_bins,"max_frames":cfg.max_frames, "hop_size":vcfg.hop_size,"n_mels":cfg.n_mels,"use_frame_pitch":cfg.use_frame_pitch}, open(OUT/"meta.json","w"), indent=1) print("sizes(KB):", {f.name: f.stat().st_size//1024 for f in OUT.glob("*.onnx")}) print("EXPORT_OK", OUT) if __name__ == "__main__": main()