#!/usr/bin/env python3
"""End-to-end synthesis from text via the exported 8k ONNX pipeline:
text -> bopomofo+ARPAbet frontend -> ids -> acoustic_encoder.onnx -> numpy host_regulate
-> acoustic_decoder.onnx -> vocoder.onnx -> 8 kHz wav. Run in moss-train-venv (g2pw + onnxruntime).
Used for M1 eval (synthesizing zh-TW / en / code-mixed test sentences). X-ASR scoring is a
separate step in moss-nano-venv (xasr_offline.py) run on the produced wavs."""
from __future__ import annotations
import argparse, json, sys
from pathlib import Path
import numpy as np, soundfile as sf, onnxruntime as ort

# Local checkout that provides frontend_bopomofo.py; adjust this path on other machines.
ZT = "/home/luigi/jetson-tts/mossnano/zhtw8k"
sys.path.insert(0, ZT)
import frontend_bopomofo as F  # g2pw bopomofo + g2p_en arpabet -> ids


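def host_regulate(cond, dur, pitch, abs_bins, max_frames):
    """Length-regulate encoder outputs on the host into the decoder's frame-level inputs.

    cond: (1, T, H) per-token conditioning; dur: (1, T) predicted frames per token;
    pitch: (1, T, ...) per-token pitch. Returns the six decoder inputs for a batch of one.
    """
    c = cond[0]; d = dur[0].astype(np.int64); d[d < 0] = 0  # clamp negative durations to 0
    T, H = c.shape
    # repeat each token vector d[t] times -> Fn frames in total
    frames = np.repeat(c, d, axis=0); Fn = frames.shape[0]
    # tok: source token of every frame; starts: index of each token's first frame
    tok = np.repeat(np.arange(T), d); starts = np.cumsum(d) - d
    within = np.arange(Fn) - starts[tok]; dpf = d[tok].astype(np.float32)
    # position inside the token: 0.0 at its first frame, 1.0 at its last (single-frame tokens get 0.0)
    rel = (within / np.maximum(dpf - 1, 1)).astype(np.float32)
    # token index scaled by the number of tokens that actually produced frames
    tc = max(1, int((d > 0).sum())); token_pos = (tok / max(1, tc - 1)).astype(np.float32)
    ld = (np.log1p(dpf) / 6.0).astype(np.float32); center = 1.0 - np.abs(rel * 2 - 1)
    # 8 per-frame meta features (the 6.0 and 40.0 scales are fixed constants, presumably matching training)
    fm = np.stack([rel, 1 - rel, center, np.sin(rel*np.pi), np.cos(rel*np.pi), token_pos, ld, dpf/40.0], -1).astype(np.float32)
    # local context: previous / current / next token vectors (edge tokens repeat themselves), expanded to frames
    prev = np.concatenate([c[:1], c[:-1]], 0); nxt = np.concatenate([c[1:], c[-1:]], 0)
    lc = np.repeat(np.concatenate([prev, c, nxt], -1), d, axis=0).astype(np.float32)
    # absolute frame position bucketed into abs_bins over a max_frames horizon
    pos = np.arange(Fn); ap = np.minimum(pos*abs_bins//max(1, max_frames), abs_bins-1).astype(np.int64)
    pf = np.repeat(pitch[0], d, axis=0).astype(np.float32)  # per-token pitch expanded to frames
    return {"frames": frames[None].astype(np.float32), "frame_meta": fm[None], "local_ctx_raw": lc[None],
            "abs_pos": ap[None], "pitch_frame": pf[None], "frame_mask": np.ones((1, Fn), bool)}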
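The index arithmetic inside `host_regulate` is easiest to check on toy durations. The sketch below re-implements only the length-regulation, relative-position and `abs_pos` steps as a standalone helper (the name `regulate_indices` is hypothetical). It does not import this script, because importing it requires the local g2pw frontend on `sys.path`.

```python
import numpy as np


def regulate_indices(d, abs_bins, max_frames):
    """Toy re-derivation of host_regulate's per-frame indices for 1-D durations d."""
    d = np.asarray(d, dtype=np.int64).copy()
    d[d < 0] = 0                                    # negative durations are clamped to 0
    T = d.shape[0]
    tok = np.repeat(np.arange(T), d)                # source token index of every frame
    starts = np.cumsum(d) - d                       # first frame of each token
    within = np.arange(tok.shape[0]) - starts[tok]  # frame offset inside its token
    dpf = d[tok].astype(np.float32)                 # duration of the token owning each frame
    rel = within / np.maximum(dpf - 1, 1)           # 0 at a token's first frame, 1 at its last
    pos = np.arange(tok.shape[0])
    # absolute position bucket, capped at the last bin
    ap = np.minimum(pos * abs_bins // max(1, max_frames), abs_bins - 1)
    return tok, within, rel.astype(np.float32), ap.astype(np.int64)


tok, within, rel, ap = regulate_indices([2, 0, 3], abs_bins=4, max_frames=10)
print(tok.tolist())     # [0, 0, 2, 2, 2]
print(within.tolist())  # [0, 1, 0, 1, 2]
print(rel.tolist())     # [0.0, 1.0, 0.0, 0.5, 1.0]
print(ap.tolist())      # [0, 0, 0, 1, 1]
```

Token 1 has duration 0, so it contributes no frames and never appears in `tok`; this is also why `host_regulate` counts only tokens with `d > 0` when computing `tc`.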


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--onnx-dir", required=True)
    ap.add_argument("--out-dir", required=True)
    ap.add_argument("--texts", required=True, help="jsonl with {id,text}")
    args = ap.parse_args()
    with open(f"{args.onnx_dir}/meta.json", encoding="utf-8") as fh:
        meta = json.load(fh)
    so = ort.SessionOptions(); so.intra_op_num_threads = 4
    sA = ort.InferenceSession(f"{args.onnx_dir}/acoustic_encoder.onnx", so, providers=["CPUExecutionProvider"])
    sB = ort.InferenceSession(f"{args.onnx_dir}/acoustic_decoder.onnx", so, providers=["CPUExecutionProvider"])
    sV = ort.InferenceSession(f"{args.onnx_dir}/vocoder.onnx", so, providers=["CPUExecutionProvider"])
    Path(args.out_dir).mkdir(parents=True, exist_ok=True)
    sr = meta["sample_rate"]
    # decoder input names, matching the keys returned by host_regulate
    bn = ["frames", "frame_meta", "local_ctx_raw", "abs_pos", "pitch_frame", "frame_mask"]
    with open(args.texts, encoding="utf-8") as fh:
        rows = [json.loads(l) for l in fh if l.strip()]
    with open(f"{args.out_dir}/synth.jsonl", "w", encoding="utf-8") as out_manifest:
        for r in rows:
            o = F.text_to_ids(r["text"])
            # batch of one: (1, T) int64 id sequences for the encoder
            phone = np.array([o["phone_ids"]], np.int64)
            tone = np.array([o["tone_ids"]], np.int64)
            lang = np.array([o["lang_ids"]], np.int64)
            spk = np.zeros(1, np.int64)  # always speaker id 0
            cond, dur, pitch = sA.run(None, {"phone": phone, "tone": tone, "lang": lang, "speaker": spk})
            reg = host_regulate(cond, dur, pitch, meta["abs_frame_bins"], meta["max_frames"])
            # host_regulate already returns the decoder dtypes
            # (float32 features, int64 abs_pos, bool frame_mask), so no casting is needed here
            feeds = {n: reg[n] for n in bn}
            mel = sB.run(None, feeds)[0]
            wav = sV.run(None, {"mel": mel.astype(np.float32)})[0].reshape(-1)
            wp = f"{args.out_dir}/{r['id']}.wav"
            sf.write(wp, wav, sr)
            out_manifest.write(json.dumps({"id": r["id"], "text": r["text"], "wav": wp, "dur": round(len(wav)/sr, 2)}, ensure_ascii=False) + "\n")
            print(f"  {r['id']}: {len(wav)/sr:.1f}s -> {wp}")
    print(f"DONE synth -> {args.out_dir}/synth.jsonl")


if __name__ == "__main__":
    main()
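The script reads `--texts` as JSON Lines with one `{id, text}` object per line, and writes `synth.jsonl` rows with `id`, `text`, `wav` and `dur` (seconds) fields. The following is a minimal sketch of preparing such an input file and summarizing a manifest before the separate X-ASR step. The helper names `write_texts` and `summarize_manifest`, the sample sentences and the durations in the toy manifest are all hypothetical.

```python
import json
import tempfile
from pathlib import Path


def write_texts(path, items):
    """Write (id, text) pairs as JSON Lines, one {"id", "text"} object per line."""
    with open(path, "w", encoding="utf-8") as fh:
        for utt_id, text in items:
            fh.write(json.dumps({"id": utt_id, "text": text}, ensure_ascii=False) + "\n")


def summarize_manifest(path):
    """Return (number of rows, total seconds) from a synth.jsonl-style manifest."""
    n, total = 0, 0.0
    with open(path, encoding="utf-8") as fh:
        for line in fh:
            if line.strip():  # skip blank lines, like the script's --texts reader
                n += 1
                total += json.loads(line)["dur"]
    return n, round(total, 2)


with tempfile.TemporaryDirectory() as tmp:
    texts = Path(tmp) / "texts.jsonl"
    write_texts(texts, [("zh-001", "今天天氣很好"), ("mix-001", "我用 Python 寫 code")])
    rows = [json.loads(l) for l in texts.read_text(encoding="utf-8").splitlines() if l.strip()]

    # toy manifest with made-up durations, standing in for real synth.jsonl output
    manifest = Path(tmp) / "synth.jsonl"
    manifest.write_text(
        "\n".join(
            json.dumps({"id": r["id"], "text": r["text"], "wav": f"{r['id']}.wav", "dur": d},
                       ensure_ascii=False)
            for r, d in zip(rows, [1.25, 1.5])
        ) + "\n",
        encoding="utf-8",
    )
    count, seconds = summarize_manifest(manifest)

print(rows[0]["text"], count, seconds)
```

Writing with `ensure_ascii=False` and an explicit UTF-8 encoding keeps the zh-TW text human-readable in both files, matching how the script writes its own manifest.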