File size: 8,122 Bytes
a37967e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1261e7b
 
a37967e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1261e7b
 
 
 
 
 
 
a37967e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
#!/usr/bin/env python3
"""Export a trained zh-TW/en 8k Inflect-Nano (acoustic + snake_8k vocoder) to ONNX.
Config-driven from the train() checkpoints. FastSpeech split:
  encoder.onnx -> numpy host_regulate -> decoder.onnx -> vocoder.onnx.
Validates full-pipeline parity vs torch. Run in moss-train-venv."""
from __future__ import annotations
import argparse, sys, math, json
from pathlib import Path
import numpy as np, torch

REPO = "/tmp/inflect-nano"
sys.path.insert(0, REPO)
from inflect_nano.acoustic import MicroFastSpeech, MicroFastSpeechConfig
from inflect_nano.vocoder import HifiGanGenerator, make_config


class EncoderHead(torch.nn.Module):
    """Export wrapper around the token-level half of the acoustic model.

    Calls the wrapped model's ``encode`` and ``predict_prosody`` with an
    all-true token mask and returns the three tensors that the host-side
    length regulator consumes.
    """

    def __init__(self, m):
        super().__init__()
        self.m = m

    def forward(self, phone, tone, lang, speaker):
        """Return ``(cond, dur, pitch)`` for the given token ids.

        Args:
            phone, tone, lang: id tensors forwarded unchanged to ``m.encode``
                (``phone`` also fixes the shape of the token mask).
            speaker: speaker id tensor forwarded unchanged to ``m.encode``.

        Returns:
            cond: encoder output plus the projected energy and brightness terms.
            dur: integer (``long``) frame count per token, at least 1.
            pitch: last-axis channels 0 and 1 of the predicted pitch, with
                channel 1 limited to [0, 1].
        """
        model = self.m
        # No padding is modelled at export time: every token position is valid.
        all_valid = torch.ones_like(phone, dtype=torch.bool)
        encoded = model.encode(phone, tone, lang, speaker, all_valid)
        log_dur, energy, bright, raw_pitch = model.predict_prosody(encoded, all_valid)

        # exp(log_dur) - 1, limited to [0, 80], rounded, then forced to >= 1 frame.
        frame_counts = torch.exp(log_dur) - 1.0
        frame_counts = frame_counts.clamp(0, 80).round()
        dur = frame_counts.clamp_min(1).long()

        # Scalar prosody values get a trailing feature axis before projection.
        energy_term = model.energy_proj(energy.unsqueeze(-1))
        bright_term = model.bright_proj(bright.unsqueeze(-1))
        cond = encoded + energy_term + bright_term

        # NOTE(review): channel 1 is presumably a voicing probability — confirm
        # against the model definition; here it is simply clamped to [0, 1].
        channel0 = raw_pitch[..., 0]
        channel1 = raw_pitch[..., 1].clamp(0, 1)
        pitch = torch.stack([channel0, channel1], dim=-1)
        return cond, dur, pitch


class DecoderHead(torch.nn.Module):
    """Export wrapper around the frame-level half of the acoustic model.

    Takes the already length-regulated, frame-level inputs and runs the wrapped
    model's frame projections, decoder blocks, GRU, mel head and postnet.
    """

    def __init__(self, m):
        super().__init__()
        self.m = m

    def forward(self, frames, frame_meta, local_ctx_raw, abs_pos, pitch_frame, frame_mask):
        """Return the mel output (mel head result plus scaled postnet residual).

        Args:
            frames: frame-level conditioning; also fed to ``refine_frame_pitch``.
            frame_meta: per-frame features passed to ``m.frame_proj``.
            local_ctx_raw: neighbour-context features passed to ``m.local_ctx``.
            abs_pos: position indices passed to ``m.abs_frame``.
            pitch_frame: frame-level pitch; only read when
                ``m.cfg.use_frame_pitch`` is truthy.
            frame_mask: mask handed to every decoder block.
        """
        model = self.m

        hidden = frames + model.frame_proj(frame_meta) + model.local_ctx(local_ctx_raw)
        hidden = hidden + model.abs_frame(abs_pos)

        # The pitch branch is optional and controlled by the model config.
        if model.cfg.use_frame_pitch:
            refined_pitch = model.refine_frame_pitch(frames, frame_meta, pitch_frame)
            hidden = hidden + model.pitch_proj(refined_pitch)

        for block in model.decoder:
            hidden = block(hidden, frame_mask)

        # Residual connection around the GRU; element 0 of its result is used.
        gru_out = model.frame_gru(hidden)[0]
        hidden = hidden + gru_out

        # Swap axes 1 and 2 of the mel head output before the postnet.
        mel = model.mel_head(hidden).transpose(1, 2)
        residual = model.postnet(mel)
        return mel + model.cfg.postnet_scale * residual


def host_regulate(cond, dur, pitch, abs_bins, max_frames):
    """Expand token-level encoder outputs to frame level with NumPy.

    Only batch item 0 of ``cond``, ``dur`` and ``pitch`` is read. Each token row
    is repeated ``dur`` times (negative durations count as 0) and the per-frame
    side inputs of the decoder are derived from the repeat pattern.

    Args:
        cond: array whose item 0 is a 2-D (tokens, features) matrix.
        dur: array whose item 0 holds one frame count per token.
        pitch: array whose item 0 holds one pitch row per token.
        abs_bins: number of absolute-position bins; indices are capped at
            ``abs_bins - 1``.
        max_frames: frame count that maps onto the full bin range.

    Returns:
        dict with keys ``frames``, ``frame_meta``, ``local_ctx_raw``,
        ``abs_pos``, ``pitch_frame`` and ``frame_mask``; every value carries a
        leading batch axis of size 1.
    """
    token_feats = cond[0]
    # astype() returns a copy, so zeroing negatives never touches the caller's array.
    counts = dur[0].astype(np.int64)
    counts[counts < 0] = 0
    num_tokens, _ = token_feats.shape

    frames = np.repeat(token_feats, counts, axis=0)
    num_frames = frames.shape[0]

    # For every frame: index of the token it came from, and its offset inside that token.
    owner = np.repeat(np.arange(num_tokens), counts)
    first_frame = np.cumsum(counts) - counts
    offset = np.arange(num_frames) - first_frame[owner]
    span = counts[owner].astype(np.float32)

    # Relative position inside the token, 0..1 (0 for single-frame tokens).
    rel = (offset / np.maximum(span - 1, 1)).astype(np.float32)
    nonempty_tokens = max(1, int((counts > 0).sum()))
    token_pos = (owner / max(1, nonempty_tokens - 1)).astype(np.float32)
    log_span = (np.log1p(span) / 6.0).astype(np.float32)
    center = 1.0 - np.abs(rel * 2 - 1)

    meta_columns = [
        rel,
        1 - rel,
        center,
        np.sin(rel * np.pi),
        np.cos(rel * np.pi),
        token_pos,
        log_span,
        span / 40.0,
    ]
    frame_meta = np.stack(meta_columns, -1).astype(np.float32)

    # Neighbour context: [previous token, token, next token], edges replicated.
    prev_tok = np.concatenate([token_feats[:1], token_feats[:-1]], 0)
    next_tok = np.concatenate([token_feats[1:], token_feats[-1:]], 0)
    context = np.concatenate([prev_tok, token_feats, next_tok], -1)
    local_ctx = np.repeat(context, counts, axis=0).astype(np.float32)

    # Quantise the frame index into abs_bins buckets over max_frames, capped at the last bin.
    frame_idx = np.arange(num_frames)
    scaled = frame_idx * abs_bins // max(1, max_frames)
    abs_pos = np.minimum(scaled, abs_bins - 1).astype(np.int64)

    pitch_frame = np.repeat(pitch[0], counts, axis=0).astype(np.float32)

    return {
        "frames": frames[None].astype(np.float32),
        "frame_meta": frame_meta[None],
        "local_ctx_raw": local_ctx[None],
        "abs_pos": abs_pos[None],
        "pitch_frame": pitch_frame[None],
        "frame_mask": np.ones((1, num_frames), bool),
    }


def main():
    """Export the acoustic encoder/decoder and the vocoder to ONNX and report parity.

    Steps:
      1. Load both checkpoints on CPU and rebuild the models from their stored configs.
      2. Run a fixed-seed random token sequence through the split pipeline
         (EncoderHead -> host_regulate -> DecoderHead) and print the mel
         difference against ``m.infer``.
      3. Export ``acoustic_encoder.onnx``, ``acoustic_decoder.onnx`` and
         ``vocoder.onnx`` into ``--out-dir``.
      4. Re-run the pipeline with onnxruntime, print the waveform difference
         against torch, and write ``meta.json``.

    Raises:
        ValueError: if the vocoder and acoustic sample rates differ.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--acoustic-ckpt", required=True)
    ap.add_argument("--vocoder-ckpt", required=True)
    ap.add_argument("--out-dir", required=True)
    # NOTE(review): --symbol-table is parsed but never read below; kept so
    # existing command lines continue to work.
    ap.add_argument("--symbol-table", default="/home/luigi/jetson-tts/mossnano/zhtw8k/symbol_table.json")
    args = ap.parse_args()
    # Imported only after argument parsing, so onnxruntime is not needed for --help.
    import onnxruntime as ort
    OUT = Path(args.out_dir); OUT.mkdir(parents=True, exist_ok=True)
    dev = torch.device("cpu")

    # SECURITY: weights_only=False unpickles arbitrary Python objects. Only run
    # this script on checkpoints from a trusted source.
    ac = torch.load(args.acoustic_ckpt, map_location=dev, weights_only=False)
    cfg = MicroFastSpeechConfig(**ac["config"])
    m = MicroFastSpeech(cfg)
    # strict=False tolerates key differences; surface them instead of silently
    # exporting a partially initialised model.
    load_result = m.load_state_dict(ac["model"], strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print(f"warning: acoustic checkpoint key mismatch: missing={list(load_result.missing_keys)} "
              f"unexpected={list(load_result.unexpected_keys)}")
    m.eval()
    # The group-duration planner uses a non-ONNX-able host loop and only adjusts inference-time
    # durations (the mel decoder is trained on GT durations). Disable it so the exported encoder's
    # plain-duration path (which keeps the contextual duration-delta) matches m.infer() for parity.
    if getattr(m, "group_duration_delta", None) is not None:
        m.group_duration_delta = None
        print("note: group_duration_planner disabled at export (host-loop; plain durations used)")
    enc, dec = EncoderHead(m).eval(), DecoderHead(m).eval()
    print(f"acoustic: sr={cfg.sample_rate} vocab={cfg.vocab_size} tone={cfg.tone_size} lang={cfg.lang_size} "
          f"abs_bins={cfg.abs_frame_bins} max_frames={cfg.max_frames}")

    # SECURITY: same trusted-input caveat as the acoustic checkpoint above.
    vc = torch.load(args.vocoder_ckpt, map_location=dev, weights_only=False)
    vcfg = make_config(vc["config"]["variant"])
    vm = HifiGanGenerator(vcfg); vm.load_state_dict(vc["generator"]); vm.remove_weight_norm(); vm.eval()
    # Explicit check rather than assert, so it still runs under `python -O`.
    if vcfg.sample_rate != cfg.sample_rate:
        raise ValueError(
            f"sample rate mismatch: vocoder={vcfg.sample_rate} acoustic={cfg.sample_rate}")

    # Decoder input names, in DecoderHead.forward argument order. Defined once and
    # reused for the torch run, the ONNX export and the onnxruntime feeds.
    bn = ["frames", "frame_meta", "local_ctx_raw", "abs_pos", "pitch_frame", "frame_mask"]

    # sample input: a short valid id sequence (plumbing test)
    T = 40
    g = torch.Generator().manual_seed(0)
    phone = torch.randint(1, min(80, cfg.vocab_size), (1, T), generator=g)
    tone = torch.randint(0, cfg.tone_size, (1, T), generator=g)
    lang = torch.randint(0, cfg.lang_size, (1, T), generator=g)
    spk = torch.zeros(1, dtype=torch.long)

    with torch.no_grad():
        cond, dur, pitch = enc(phone, tone, lang, spk)
        reg = host_regulate(cond.numpy(), dur.numpy(), pitch.numpy(), cfg.abs_frame_bins, cfg.max_frames)
        bt = tuple(torch.from_numpy(reg[k]).clone() for k in bn)
        mel_split = dec(*bt)
        mel_ref = m.infer(phone, tone, lang, spk)
    print(f"mel parity max_abs_diff={float((mel_ref-mel_split).abs().max()):.2e}")

    torch.onnx.export(enc, (phone, tone, lang, spk), str(OUT/"acoustic_encoder.onnx"),
        input_names=["phone","tone","lang","speaker"], output_names=["conditioned","durations","pitch"],
        dynamic_axes={"phone":{1:"T"},"tone":{1:"T"},"lang":{1:"T"},"conditioned":{1:"T"},"durations":{1:"T"},"pitch":{1:"T"}},
        opset_version=17, dynamo=False)
    torch.onnx.export(dec, bt, str(OUT/"acoustic_decoder.onnx"),
        input_names=bn, output_names=["mel"],
        dynamic_axes={**{n:{1:"F"} for n in bn}, "mel":{2:"F"}}, opset_version=17, dynamo=False)
    dummy = torch.randn(1, vcfg.num_mels, 60)  # match vocoder mel count (40 for snake_8k40, 80 otherwise)
    torch.onnx.export(vm, dummy, str(OUT/"vocoder.onnx"), input_names=["mel"], output_names=["wav"],
        dynamic_axes={"mel":{2:"frames"},"wav":{2:"samples"}}, opset_version=17, dynamo=False)

    # full-pipeline ONNX parity
    sA = ort.InferenceSession(str(OUT/"acoustic_encoder.onnx"), providers=["CPUExecutionProvider"])
    sB = ort.InferenceSession(str(OUT/"acoustic_decoder.onnx"), providers=["CPUExecutionProvider"])
    sV = ort.InferenceSession(str(OUT/"vocoder.onnx"), providers=["CPUExecutionProvider"])
    oc, od, op = sA.run(None, {"phone":phone.numpy(),"tone":tone.numpy(),"lang":lang.numpy(),"speaker":spk.numpy()})
    reg2 = host_regulate(oc, od, op, cfg.abs_frame_bins, cfg.max_frames)
    # Feed dtypes: int64 for abs_pos, bool mask unchanged, float32 for everything else.
    feeds = {}
    for name in bn:
        arr = reg2[name]
        if name == "abs_pos":
            arr = arr.astype(np.int64)
        elif arr.dtype != bool:
            arr = arr.astype(np.float32)
        feeds[name] = arr
    mel_onnx = sB.run(None, feeds)[0]
    wav_onnx = sV.run(None, {"mel": mel_onnx.astype(np.float32)})[0]
    with torch.inference_mode(): wav_ref = vm(mel_ref).numpy()
    # Compare only the overlapping samples in case the two lengths differ.
    n = min(wav_ref.shape[-1], wav_onnx.shape[-1])
    print(f"FULL-PIPELINE wav parity max_abs_diff={float(np.abs(wav_ref[...,:n]-wav_onnx[...,:n]).max()):.2e}")
    # save metadata for the Nano runtime
    meta = {"sample_rate":cfg.sample_rate,"abs_frame_bins":cfg.abs_frame_bins,"max_frames":cfg.max_frames,
            "hop_size":vcfg.hop_size,"n_mels":cfg.n_mels,"use_frame_pitch":cfg.use_frame_pitch}
    # Context manager guarantees the file is flushed and closed before EXPORT_OK is printed.
    with open(OUT/"meta.json", "w", encoding="utf-8") as fh:
        json.dump(meta, fh, indent=1)
    print("sizes(KB):", {f.name: f.stat().st_size//1024 for f in OUT.glob("*.onnx")})
    print("EXPORT_OK", OUT)


# Run the export only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()