# Source: PrimeTTS / scripts/export_8k.py
# Commit a37967e (verified): "PrimeTTS: full training pipeline + weights (fine-tune of Inflect-Nano-v1)"
# File size: 7.54 kB
#!/usr/bin/env python3
"""Export a trained zh-TW/en 8k Inflect-Nano (acoustic + snake_8k vocoder) to ONNX.
Config-driven from the train() checkpoints. FastSpeech split:
encoder.onnx -> numpy host_regulate -> decoder.onnx -> vocoder.onnx.
Validates full-pipeline parity vs torch. Run in moss-train-venv."""
from __future__ import annotations
import argparse, sys, math, json
from pathlib import Path
import numpy as np, torch
REPO = "/tmp/inflect-nano"
sys.path.insert(0, REPO)
from inflect_nano.acoustic import MicroFastSpeech, MicroFastSpeechConfig
from inflect_nano.vocoder import HifiGanGenerator, make_config
class EncoderHead(torch.nn.Module):
    """Token-level half of the acoustic model, wrapped for standalone export.

    Holds the full acoustic model as ``self.m`` and calls only its
    ``encode``, ``predict_prosody``, ``energy_proj`` and ``bright_proj``
    members.
    """

    def __init__(self, m):
        super().__init__()
        self.m = m

    def forward(self, phone, tone, lang, speaker):
        """Return ``(conditioned, durations, pitch)`` for one token sequence.

        Every token is treated as valid: the mask handed to the model is
        all-True with the same shape as ``phone``.
        """
        model = self.m
        all_valid = torch.ones_like(phone, dtype=torch.bool)

        hidden = model.encode(phone, tone, lang, speaker, all_valid)
        log_dur, energy, bright, pitch = model.predict_prosody(hidden, all_valid)

        # Invert the log(1 + d) parameterisation, then force an integer
        # frame count in [1, 80] per token.
        frame_counts = torch.exp(log_dur) - 1.0
        frame_counts = frame_counts.clamp(0, 80).round().clamp_min(1).long()

        energy_term = model.energy_proj(energy.unsqueeze(-1))
        bright_term = model.bright_proj(bright.unsqueeze(-1))
        conditioned = hidden + energy_term + bright_term

        # First pitch channel passes through; the second is limited to [0, 1].
        first_channel = pitch[..., 0]
        second_channel = pitch[..., 1].clamp(0, 1)
        pitch_out = torch.stack([first_channel, second_channel], dim=-1)

        return conditioned, frame_counts, pitch_out
class DecoderHead(torch.nn.Module):
    """Frame-level half of the acoustic model, wrapped for standalone export.

    Consumes the already length-regulated frame features and produces the
    mel output. Holds the full acoustic model as ``self.m``.
    """

    def __init__(self, m):
        super().__init__()
        self.m = m

    def forward(self, frames, frame_meta, local_ctx_raw, abs_pos, pitch_frame, frame_mask):
        """Return the mel tensor (post-net residual included) for the frames."""
        model = self.m

        # Sum the per-frame conditioning terms onto the expanded token features.
        hidden = frames + model.frame_proj(frame_meta)
        hidden = hidden + model.local_ctx(local_ctx_raw)
        hidden = hidden + model.abs_frame(abs_pos)
        if model.cfg.use_frame_pitch:
            hidden = hidden + model.pitch_proj(pitch_frame)

        for block in model.decoder:
            hidden = block(hidden, frame_mask)

        # Residual connection around the recurrent layer; element 0 of its
        # return value is the per-frame output.
        recurrent_out = model.frame_gru(hidden)[0]
        hidden = hidden + recurrent_out

        # The mel head output has its last two axes swapped before the post-net.
        coarse_mel = model.mel_head(hidden).transpose(1, 2)
        refinement = model.postnet(coarse_mel)
        return coarse_mel + model.cfg.postnet_scale * refinement
def host_regulate(cond, dur, pitch, abs_bins, max_frames):
    """Expand token-level encoder outputs to frame-level decoder inputs in NumPy.

    Only batch item 0 of ``cond``, ``dur`` and ``pitch`` is used. Each token
    is repeated ``dur`` times (negative durations count as zero) and the
    per-frame side features are built alongside.

    Returns a dict with the keys ``frames``, ``frame_meta``,
    ``local_ctx_raw``, ``abs_pos``, ``pitch_frame`` and ``frame_mask``,
    each carrying a leading batch axis of size 1.
    """
    token_feats = cond[0]
    # astype() copies, so clamping here never touches the caller's array.
    repeats = dur[0].astype(np.int64)
    np.maximum(repeats, 0, out=repeats)
    n_tokens, _ = token_feats.shape

    frames = np.repeat(token_feats, repeats, axis=0)
    n_frames = frames.shape[0]

    # For each frame: the token it came from and its offset inside that token.
    token_of_frame = np.repeat(np.arange(n_tokens), repeats)
    token_start = np.cumsum(repeats) - repeats
    offset_in_token = np.arange(n_frames) - token_start[token_of_frame]
    frames_in_token = repeats[token_of_frame].astype(np.float32)

    # Relative position inside the token, 0 at its first frame.
    rel = (offset_in_token / np.maximum(frames_in_token - 1, 1)).astype(np.float32)

    # Token index scaled by the number of tokens that received any frames.
    active_tokens = max(1, int((repeats > 0).sum()))
    token_pos = (token_of_frame / max(1, active_tokens - 1)).astype(np.float32)

    log_dur = (np.log1p(frames_in_token) / 6.0).astype(np.float32)
    center = 1.0 - np.abs(rel * 2 - 1)

    meta_columns = [
        rel,
        1 - rel,
        center,
        np.sin(rel * np.pi),
        np.cos(rel * np.pi),
        token_pos,
        log_dur,
        frames_in_token / 40.0,
    ]
    frame_meta = np.stack(meta_columns, -1).astype(np.float32)

    # Previous / current / next token features; the first and last tokens
    # reuse themselves as their missing neighbour.
    prev_tokens = np.concatenate([token_feats[:1], token_feats[:-1]], 0)
    next_tokens = np.concatenate([token_feats[1:], token_feats[-1:]], 0)
    neighbourhood = np.concatenate([prev_tokens, token_feats, next_tokens], -1)
    local_ctx = np.repeat(neighbourhood, repeats, axis=0).astype(np.float32)

    # Quantise the frame index into abs_bins buckets spanning max_frames.
    frame_index = np.arange(n_frames)
    bucket = frame_index * abs_bins // max(1, max_frames)
    abs_pos = np.minimum(bucket, abs_bins - 1).astype(np.int64)

    pitch_frame = np.repeat(pitch[0], repeats, axis=0).astype(np.float32)

    return {
        "frames": frames[None].astype(np.float32),
        "frame_meta": frame_meta[None],
        "local_ctx_raw": local_ctx[None],
        "abs_pos": abs_pos[None],
        "pitch_frame": pitch_frame[None],
        "frame_mask": np.ones((1, n_frames), bool),
    }
def main():
    """Export the acoustic encoder/decoder and the vocoder to ONNX.

    Steps:
      1. Load both checkpoints on CPU and rebuild the models from their
         stored configs.
      2. Run a random id sequence through the split torch pipeline
         (encoder -> host_regulate -> decoder) and print its max abs
         difference against ``m.infer``.
      3. Export ``acoustic_encoder.onnx``, ``acoustic_decoder.onnx`` and
         ``vocoder.onnx`` into ``--out-dir``.
      4. Run the three ONNX graphs end to end with onnxruntime and print the
         waveform max abs difference against torch.
      5. Write ``meta.json`` next to the ONNX files.

    The parity numbers are only printed; no threshold is enforced.

    Raises:
        ValueError: if the vocoder and acoustic sample rates differ.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--acoustic-ckpt", required=True)
    ap.add_argument("--vocoder-ckpt", required=True)
    ap.add_argument("--out-dir", required=True)
    # NOTE(review): this option is parsed but never read below; it is kept so
    # existing command lines keep working.
    ap.add_argument("--symbol-table", default="/home/luigi/jetson-tts/mossnano/zhtw8k/symbol_table.json")
    args = ap.parse_args()

    import onnxruntime as ort

    OUT = Path(args.out_dir)
    OUT.mkdir(parents=True, exist_ok=True)
    dev = torch.device("cpu")

    # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects.
    # Only point this script at checkpoints you produced or trust.
    ac = torch.load(args.acoustic_ckpt, map_location=dev, weights_only=False)
    cfg = MicroFastSpeechConfig(**ac["config"])
    m = MicroFastSpeech(cfg)
    m.load_state_dict(ac["model"])
    m.eval()
    enc, dec = EncoderHead(m).eval(), DecoderHead(m).eval()
    print(f"acoustic: sr={cfg.sample_rate} vocab={cfg.vocab_size} tone={cfg.tone_size} lang={cfg.lang_size} "
          f"abs_bins={cfg.abs_frame_bins} max_frames={cfg.max_frames}")

    # SECURITY NOTE: same pickle caveat as the acoustic checkpoint above.
    vc = torch.load(args.vocoder_ckpt, map_location=dev, weights_only=False)
    vcfg = make_config(vc["config"]["variant"])
    vm = HifiGanGenerator(vcfg)
    vm.load_state_dict(vc["generator"])
    vm.remove_weight_norm()
    vm.eval()
    # Explicit raise instead of `assert`, so the check survives `python -O`.
    if vcfg.sample_rate != cfg.sample_rate:
        raise ValueError(
            f"sample-rate mismatch: vocoder={vcfg.sample_rate} acoustic={cfg.sample_rate}"
        )

    # sample input: a short valid id sequence (plumbing test)
    T = 40
    g = torch.Generator().manual_seed(0)
    phone = torch.randint(1, min(80, cfg.vocab_size), (1, T), generator=g)
    tone = torch.randint(0, cfg.tone_size, (1, T), generator=g)
    lang = torch.randint(0, cfg.lang_size, (1, T), generator=g)
    spk = torch.zeros(1, dtype=torch.long)

    bn = ["frames", "frame_meta", "local_ctx_raw", "abs_pos", "pitch_frame", "frame_mask"]
    with torch.no_grad():
        cond, dur, pitch = enc(phone, tone, lang, spk)
        reg = host_regulate(cond.numpy(), dur.numpy(), pitch.numpy(), cfg.abs_frame_bins, cfg.max_frames)
        bt = tuple(torch.from_numpy(reg[k]).clone() for k in bn)
        mel_split = dec(*bt)
        mel_ref = m.infer(phone, tone, lang, spk)
        print(f"mel parity max_abs_diff={float((mel_ref-mel_split).abs().max()):.2e}")

    torch.onnx.export(
        enc, (phone, tone, lang, spk), str(OUT / "acoustic_encoder.onnx"),
        input_names=["phone", "tone", "lang", "speaker"],
        output_names=["conditioned", "durations", "pitch"],
        dynamic_axes={"phone": {1: "T"}, "tone": {1: "T"}, "lang": {1: "T"},
                      "conditioned": {1: "T"}, "durations": {1: "T"}, "pitch": {1: "T"}},
        opset_version=17, dynamo=False)
    torch.onnx.export(
        dec, bt, str(OUT / "acoustic_decoder.onnx"),
        input_names=bn, output_names=["mel"],
        dynamic_axes={**{n: {1: "F"} for n in bn}, "mel": {2: "F"}},
        opset_version=17, dynamo=False)
    dummy = torch.randn(1, vcfg.num_mels, 60)  # match vocoder mel count (40 for snake_8k40, 80 otherwise)
    torch.onnx.export(
        vm, dummy, str(OUT / "vocoder.onnx"),
        input_names=["mel"], output_names=["wav"],
        dynamic_axes={"mel": {2: "frames"}, "wav": {2: "samples"}},
        opset_version=17, dynamo=False)

    # full-pipeline ONNX parity
    sA = ort.InferenceSession(str(OUT / "acoustic_encoder.onnx"), providers=["CPUExecutionProvider"])
    sB = ort.InferenceSession(str(OUT / "acoustic_decoder.onnx"), providers=["CPUExecutionProvider"])
    sV = ort.InferenceSession(str(OUT / "vocoder.onnx"), providers=["CPUExecutionProvider"])
    oc, od, op = sA.run(None, {"phone": phone.numpy(), "tone": tone.numpy(),
                               "lang": lang.numpy(), "speaker": spk.numpy()})
    reg2 = host_regulate(oc, od, op, cfg.abs_frame_bins, cfg.max_frames)
    # Everything is fed as float32 except the boolean mask and the int64 bin ids.
    feeds = {n: (reg2[n].astype(np.float32) if reg2[n].dtype != bool else reg2[n]) for n in bn}
    feeds["abs_pos"] = reg2["abs_pos"].astype(np.int64)
    mel_onnx = sB.run(None, feeds)[0]
    wav_onnx = sV.run(None, {"mel": mel_onnx.astype(np.float32)})[0]
    with torch.inference_mode():
        wav_ref = vm(mel_ref).numpy()
    # Compare only the overlapping samples in case the two lengths differ.
    n = min(wav_ref.shape[-1], wav_onnx.shape[-1])
    print(f"FULL-PIPELINE wav parity max_abs_diff={float(np.abs(wav_ref[...,:n]-wav_onnx[...,:n]).max()):.2e}")

    # save metadata for the Nano runtime
    meta = {"sample_rate": cfg.sample_rate, "abs_frame_bins": cfg.abs_frame_bins, "max_frames": cfg.max_frames,
            "hop_size": vcfg.hop_size, "n_mels": cfg.n_mels, "use_frame_pitch": cfg.use_frame_pitch}
    # Context manager guarantees the file is flushed and closed before we report success.
    with open(OUT / "meta.json", "w", encoding="utf-8") as fh:
        json.dump(meta, fh, indent=1)

    print("sizes(KB):", {f.name: f.stat().st_size//1024 for f in OUT.glob("*.onnx")})
    print("EXPORT_OK", OUT)
# Run the export only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()