| |
"""Export a trained zh-TW/en 8 kHz Inflect-Nano model (acoustic model + snake_8k vocoder) to ONNX.

Config-driven from the checkpoints written by train(). FastSpeech-style split:
acoustic_encoder.onnx -> numpy host_regulate -> acoustic_decoder.onnx -> vocoder.onnx,
plus a meta.json with the constants the host runtime needs.
Validates full-pipeline parity against torch. Run in moss-train-venv."""
| from __future__ import annotations |
import argparse
import json
import math
import os
import sys
from pathlib import Path

import numpy as np
import torch
|
|
# Checkout of the inflect-nano repository that provides the ``inflect_nano``
# package imported below. Defaults to the original scratch location; set
# $INFLECT_NANO_REPO to point at another checkout without editing this file.
REPO = os.environ.get("INFLECT_NANO_REPO", "/tmp/inflect-nano")
sys.path.insert(0, REPO)
| from inflect_nano.acoustic import MicroFastSpeech, MicroFastSpeechConfig |
| from inflect_nano.vocoder import HifiGanGenerator, make_config |
|
|
|
|
class EncoderHead(torch.nn.Module):
    """Export wrapper around the token-level half of a MicroFastSpeech model.

    ``forward(phone, tone, lang, speaker)`` encodes the token sequence with an
    all-True token mask, runs the prosody predictor and returns
    ``(cond, dur, pitch)``:

    * ``cond``  -- encoder output plus the projected energy and brightness
      predictions;
    * ``dur``   -- int64 frames per token: ``exp(log_dur) - 1`` clamped to
      [0, 80], rounded, then floored at 1;
    * ``pitch`` -- predicted pitch with channel 0 passed through and channel 1
      clamped to [0, 1].
    """

    def __init__(self, m):
        super().__init__()
        self.m = m

    def forward(self, phone, tone, lang, speaker):
        model = self.m
        # Single unpadded utterance: every token position is valid.
        token_mask = torch.ones_like(phone, dtype=torch.bool)
        hidden = model.encode(phone, tone, lang, speaker, token_mask)
        log_dur, energy, bright, pitch = model.predict_prosody(hidden, token_mask)

        # Presumably the predictor targets log(1 + frames) -- confirm against
        # training; the clamp/round/floor below is what the export relies on.
        raw_frames = torch.exp(log_dur) - 1.0
        durations = raw_frames.clamp(0, 80).round().clamp_min(1).long()

        energy_term = model.energy_proj(energy.unsqueeze(-1))
        bright_term = model.bright_proj(bright.unsqueeze(-1))
        cond = hidden + energy_term + bright_term

        pitch_out = torch.stack(
            [pitch[..., 0], pitch[..., 1].clamp(0, 1)],
            dim=-1,
        )
        return cond, durations, pitch_out
|
|
|
|
class DecoderHead(torch.nn.Module):
    """Export wrapper around the frame-level half of a MicroFastSpeech model.

    Takes the per-frame tensors produced by ``host_regulate`` and returns the
    mel output, transposed so channels come before frames, after the postnet
    residual (``mel + cfg.postnet_scale * postnet(mel)``).
    """

    def __init__(self, m):
        super().__init__()
        self.m = m

    def forward(self, frames, frame_meta, local_ctx_raw, abs_pos, pitch_frame, frame_mask):
        model = self.m
        cfg = model.cfg

        # Sum of all per-frame input embeddings.
        hidden = frames + model.frame_proj(frame_meta) + model.local_ctx(local_ctx_raw)
        hidden = hidden + model.abs_frame(abs_pos)
        if cfg.use_frame_pitch:
            hidden = hidden + model.pitch_proj(pitch_frame)

        for block in model.decoder:
            hidden = block(hidden, frame_mask)

        # Residual over the first output of frame_gru.
        gru_out = model.frame_gru(hidden)[0]
        hidden = hidden + gru_out

        mel = model.mel_head(hidden).transpose(1, 2)
        refinement = model.postnet(mel)
        return mel + cfg.postnet_scale * refinement
|
|
|
|
def host_regulate(cond, dur, pitch, abs_bins, max_frames):
    """Length-regulate encoder outputs on the host, producing decoder.onnx feeds.

    Only batch item 0 is used. Token ``i`` is repeated ``dur[0, i]`` times
    (negative durations count as 0), and the per-frame side inputs are built
    alongside.

    Args:
        cond: encoder conditioning; ``cond[0]`` must be 2-D (tokens x hidden).
        dur: per-token frame counts; ``dur[0]`` is cast to int64.
        pitch: per-token pitch features; ``pitch[0]`` is repeated like ``cond``.
        abs_bins: number of absolute-position buckets.
        max_frames: frame count that maps onto the full bucket range.

    Returns:
        dict with a leading batch axis of 1 on every entry, in decoder input order:
        ``frames`` (float32), ``frame_meta`` (float32, 8 features: rel, 1-rel,
        center ramp, sin(pi*rel), cos(pi*rel), token_pos, log1p(dur)/6, dur/40),
        ``local_ctx_raw`` (float32, previous|current|next token conditioning),
        ``abs_pos`` (int64 bucket ids), ``pitch_frame`` (float32) and
        ``frame_mask`` (bool, all True).

    NOTE(review): zero-duration tokens still advance the owning-token index
    but are excluded from the token_pos denominator, so token_pos can exceed 1
    when such tokens exist -- confirm this mirrors the torch implementation.
    """
    tokens = cond[0]
    reps = np.maximum(dur[0].astype(np.int64), 0)
    n_tok, _hidden = tokens.shape

    # Expand tokens to frames; remember which token owns each frame.
    frames = np.repeat(tokens, reps, axis=0)
    n_frames = frames.shape[0]
    owner = np.repeat(np.arange(n_tok), reps)
    first = np.cumsum(reps) - reps
    offset = np.arange(n_frames) - first[owner]
    span = reps[owner].astype(np.float32)

    # Per-frame positional features.
    rel = (offset / np.maximum(span - 1, 1)).astype(np.float32)
    active = max(1, int((reps > 0).sum()))
    token_pos = (owner / max(1, active - 1)).astype(np.float32)
    log_span = (np.log1p(span) / 6.0).astype(np.float32)
    center = 1.0 - np.abs(rel * 2 - 1)
    angle = rel * np.pi
    frame_meta = np.stack(
        [rel, 1 - rel, center, np.sin(angle), np.cos(angle), token_pos, log_span, span / 40.0],
        -1,
    ).astype(np.float32)

    # Neighbour context: [previous | current | next] token, edges replicated.
    prev_tok = np.concatenate([tokens[:1], tokens[:-1]], 0)
    next_tok = np.concatenate([tokens[1:], tokens[-1:]], 0)
    window = np.concatenate([prev_tok, tokens, next_tok], -1)
    local_ctx_raw = np.repeat(window, reps, axis=0).astype(np.float32)

    # Absolute frame position bucketed into abs_bins, capped at the last bucket.
    bucket = np.arange(n_frames) * abs_bins // max(1, max_frames)
    abs_pos = np.minimum(bucket, abs_bins - 1).astype(np.int64)

    pitch_frame = np.repeat(pitch[0], reps, axis=0).astype(np.float32)

    return {
        "frames": frames[None].astype(np.float32),
        "frame_meta": frame_meta[None],
        "local_ctx_raw": local_ctx_raw[None],
        "abs_pos": abs_pos[None],
        "pitch_frame": pitch_frame[None],
        "frame_mask": np.ones((1, n_frames), bool),
    }
|
|
|
|
def _report_parity(label, diff, tol):
    """Print ``<label> max_abs_diff=<diff>`` and report whether it is within ``tol``.

    Args:
        label: prefix of the printed line.
        diff: maximum absolute difference to report.
        tol: optional tolerance; ``None`` disables the check.

    Returns:
        True when ``tol`` is None or ``diff <= tol``; False otherwise
        (including when ``diff`` is NaN).
    """
    print(f"{label} max_abs_diff={diff:.2e}")
    return tol is None or diff <= tol


def main():
    """Export encoder/decoder/vocoder ONNX graphs and check parity against torch.

    1. Load the acoustic (MicroFastSpeech) and vocoder (HifiGanGenerator)
       checkpoints given on the command line.
    2. Run a seeded random 40-token input through the torch split path
       (EncoderHead -> host_regulate -> DecoderHead) and print its mel
       difference from ``m.infer``.
    3. Export acoustic_encoder.onnx, acoustic_decoder.onnx and vocoder.onnx
       (opset 17, non-dynamo exporter) into --out-dir.
    4. Run the same input through onnxruntime and print the waveform
       difference from the torch reference.
    5. Write meta.json with the constants the host runtime needs.

    Raises:
        SystemExit: if the acoustic and vocoder sample rates differ, or if
            --parity-tol is given and a printed parity difference exceeds it.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--acoustic-ckpt", required=True)
    ap.add_argument("--vocoder-ckpt", required=True)
    ap.add_argument("--out-dir", required=True)
    # NOTE(review): --symbol-table is parsed but never read by this script;
    # kept so existing command lines keep working.
    ap.add_argument("--symbol-table", default="/home/luigi/jetson-tts/mossnano/zhtw8k/symbol_table.json")
    ap.add_argument(
        "--parity-tol", type=float, default=None,
        help="exit non-zero if a printed parity max_abs_diff exceeds this value",
    )
    args = ap.parse_args()
    import onnxruntime as ort  # only the exporter entry point needs onnxruntime

    OUT = Path(args.out_dir)
    OUT.mkdir(parents=True, exist_ok=True)
    dev = torch.device("cpu")
    # Decoder graph inputs, in the order host_regulate returns and DecoderHead takes them.
    bn = ["frames", "frame_meta", "local_ctx_raw", "abs_pos", "pitch_frame", "frame_mask"]

    # --- Acoustic model ---------------------------------------------------
    # weights_only=False unpickles arbitrary objects: only load trusted checkpoints.
    ac = torch.load(args.acoustic_ckpt, map_location=dev, weights_only=False)
    cfg = MicroFastSpeechConfig(**ac["config"])
    m = MicroFastSpeech(cfg)
    m.load_state_dict(ac["model"])
    m.eval()
    enc, dec = EncoderHead(m).eval(), DecoderHead(m).eval()
    print(f"acoustic: sr={cfg.sample_rate} vocab={cfg.vocab_size} tone={cfg.tone_size} lang={cfg.lang_size} "
          f"abs_bins={cfg.abs_frame_bins} max_frames={cfg.max_frames}")

    # --- Vocoder ----------------------------------------------------------
    vc = torch.load(args.vocoder_ckpt, map_location=dev, weights_only=False)
    vcfg = make_config(vc["config"]["variant"])
    vm = HifiGanGenerator(vcfg)
    vm.load_state_dict(vc["generator"])
    vm.remove_weight_norm()
    vm.eval()
    # Explicit check instead of `assert`, which `python -O` strips.
    if vcfg.sample_rate != cfg.sample_rate:
        raise SystemExit(
            f"sample_rate mismatch: acoustic={cfg.sample_rate} vocoder={vcfg.sample_rate}"
        )

    # --- Seeded random test input -----------------------------------------
    T = 40
    g = torch.Generator().manual_seed(0)
    phone = torch.randint(1, min(80, cfg.vocab_size), (1, T), generator=g)
    tone = torch.randint(0, cfg.tone_size, (1, T), generator=g)
    lang = torch.randint(0, cfg.lang_size, (1, T), generator=g)
    spk = torch.zeros(1, dtype=torch.long)

    # --- Torch split path vs. monolithic m.infer --------------------------
    with torch.no_grad():
        cond, dur, pitch = enc(phone, tone, lang, spk)
        reg = host_regulate(cond.numpy(), dur.numpy(), pitch.numpy(), cfg.abs_frame_bins, cfg.max_frames)
        bt = tuple(torch.from_numpy(reg[k]).clone() for k in bn)
        mel_split = dec(*bt)
        mel_ref = m.infer(phone, tone, lang, spk)
        mel_ok = _report_parity("mel parity", float((mel_ref - mel_split).abs().max()), args.parity_tol)

    # --- ONNX export ------------------------------------------------------
    torch.onnx.export(enc, (phone, tone, lang, spk), str(OUT / "acoustic_encoder.onnx"),
                      input_names=["phone", "tone", "lang", "speaker"],
                      output_names=["conditioned", "durations", "pitch"],
                      dynamic_axes={"phone": {1: "T"}, "tone": {1: "T"}, "lang": {1: "T"},
                                    "conditioned": {1: "T"}, "durations": {1: "T"}, "pitch": {1: "T"}},
                      opset_version=17, dynamo=False)
    torch.onnx.export(dec, bt, str(OUT / "acoustic_decoder.onnx"),
                      input_names=bn, output_names=["mel"],
                      dynamic_axes={**{n: {1: "F"} for n in bn}, "mel": {2: "F"}},
                      opset_version=17, dynamo=False)
    dummy = torch.randn(1, vcfg.num_mels, 60)
    torch.onnx.export(vm, dummy, str(OUT / "vocoder.onnx"), input_names=["mel"], output_names=["wav"],
                      dynamic_axes={"mel": {2: "frames"}, "wav": {2: "samples"}},
                      opset_version=17, dynamo=False)

    # --- onnxruntime full pipeline vs. torch reference --------------------
    providers = ["CPUExecutionProvider"]
    sA = ort.InferenceSession(str(OUT / "acoustic_encoder.onnx"), providers=providers)
    sB = ort.InferenceSession(str(OUT / "acoustic_decoder.onnx"), providers=providers)
    sV = ort.InferenceSession(str(OUT / "vocoder.onnx"), providers=providers)
    oc, od, op = sA.run(None, {"phone": phone.numpy(), "tone": tone.numpy(),
                               "lang": lang.numpy(), "speaker": spk.numpy()})
    reg2 = host_regulate(oc, od, op, cfg.abs_frame_bins, cfg.max_frames)
    # host_regulate already casts to the traced dtypes (float32 features,
    # int64 abs_pos, bool frame_mask), so no per-input conversion is needed.
    feeds = {n: reg2[n] for n in bn}
    mel_onnx = sB.run(None, feeds)[0]
    wav_onnx = sV.run(None, {"mel": mel_onnx.astype(np.float32)})[0]
    with torch.inference_mode():
        wav_ref = vm(mel_ref).numpy()
    # The two paths may yield different lengths; compare the overlapping prefix.
    n = min(wav_ref.shape[-1], wav_onnx.shape[-1])
    wav_ok = _report_parity("FULL-PIPELINE wav parity",
                            float(np.abs(wav_ref[..., :n] - wav_onnx[..., :n]).max()),
                            args.parity_tol)
    if not (mel_ok and wav_ok):
        raise SystemExit(f"parity check failed: max_abs_diff above --parity-tol={args.parity_tol:.2e}")

    # --- Runtime metadata -------------------------------------------------
    meta = {"sample_rate": cfg.sample_rate, "abs_frame_bins": cfg.abs_frame_bins,
            "max_frames": cfg.max_frames, "hop_size": vcfg.hop_size, "n_mels": cfg.n_mels,
            "use_frame_pitch": cfg.use_frame_pitch}
    with open(OUT / "meta.json", "w") as fh:  # closed deterministically, unlike a bare open()
        json.dump(meta, fh, indent=1)
    print("sizes(KB):", {f.name: f.stat().st_size // 1024 for f in OUT.glob("*.onnx")})
    print("EXPORT_OK", OUT)
|
|
|
|
# Run the export only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|