File size: 8,108 Bytes
a37967e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f8f98a9
a37967e
 
 
f8f98a9
a37967e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
#!/usr/bin/env python3
"""v4 aligner: PHONE-LEVEL forced alignment via espeak phoneme-CTC.

Root-cause fix. v1/v3 did CHAR-level CTC then heuristically SPLIT each char span
across its phones -> wrong RELATIVE phone durations -> over-smoothed acoustic.
v4 aligns OUR frontend's phone sequence DIRECTLY:
  each phone -> one IPA token in facebook/wav2vec2-lv-60-espeak-cv-ft's vocab
  -> torchaudio.forced_align(emissions, our_phone_ipa_targets) -> per-phone frames.
No splitting. Durations are contiguous (cover all frames); fit_durations() rescales
to true mel length at train time, so only RELATIVE proportions matter -- and those
are now correct from the phone recognizer.

Non-alignable phones (SP / punctuation / OOV) carry a small relative value.
"""
from __future__ import annotations
import argparse, json, sys
import numpy as np, torch, soundfile as sf, librosa
sys.path.insert(0, "/home/luigi/jetson-tts/mossnano/zhtw8k")

# Hugging Face repo id of the espeak phoneme-CTC model. Aligner loads both the
# token vocabulary (vocab.json) and the CTC weights from this repo.
ESPEAK = "facebook/wav2vec2-lv-60-espeak-cv-ft"

# ARPABET (en, stress stripped) -> single IPA token present in the espeak vocab.
# Consulted by Aligner.durations for phones whose lang id is 1; a phone with no
# entry here, or whose IPA token is absent from the loaded vocab, is treated as
# non-alignable and gets a derived (not measured) duration.
ARPA2IPA = {
    'AA':'ɑ','AE':'æ','AH':'ʌ','AO':'ɔ','AW':'aʊ','AY':'aɪ','B':'b','CH':'tʃ',
    'D':'d','DH':'ð','EH':'ɛ','ER':'ɚ','EY':'eɪ','F':'f','G':'ɡ','HH':'h',
    'IH':'ɪ','IY':'i','JH':'dʒ','K':'k','L':'l','M':'m','N':'n','NG':'ŋ',
    'OW':'oʊ','OY':'ɔɪ','P':'p','R':'ɹ','S':'s','SH':'ʃ','T':'t','TH':'θ',
    'UH':'ʊ','UW':'u','V':'v','W':'w','Y':'j','Z':'z','ZH':'ʒ',
}
# bopomofo (zh) -> toneless IPA token in the espeak Mandarin inventory (best-effort).
# Consulted by Aligner.durations for every phone whose lang id is not 1.
# NOTE(review): the nasal finals ㄢ/ㄣ/ㄤ/ㄥ map to a bare vowel token, so the nasal
# coda is not a separate alignment target.
# NOTE(review): aspiration is spelled 'h' for ㄑ/ㄘ but 'ʰ' for the other
# aspirates -- presumably mirroring the model's vocab spelling; confirm against
# vocab.json, because a token missing from the vocab is silently skipped.
BOPO2IPA = {
    'ㄅ':'p','ㄆ':'pʰ','ㄇ':'m','ㄈ':'f','ㄉ':'t','ㄊ':'tʰ','ㄋ':'n','ㄌ':'l',
    'ㄍ':'k','ㄎ':'kʰ','ㄏ':'x','ㄐ':'tɕ','ㄑ':'tɕh','ㄒ':'ɕ','ㄓ':'tʃ','ㄔ':'tʃʰ',
    'ㄕ':'ʂ','ㄖ':'ʐ','ㄗ':'ts','ㄘ':'tsh','ㄙ':'s',
    'ㄧ':'i','ㄨ':'u','ㄩ':'y','ㄚ':'a','ㄛ':'o','ㄜ':'ɤ','ㄝ':'e','ㄞ':'ai',
    'ㄟ':'ei','ㄠ':'au','ㄡ':'ou','ㄢ':'a','ㄣ':'ə','ㄤ':'ɑ','ㄥ':'ə','ㄦ':'ɚ',
    'ㄭ':'ɨ',
}


def build_frontend_with_charidx():
    """Build the text frontend and return ``(run, F)``.

    ``F`` is the imported ``frontend_bopomofo`` module (initialised through
    ``F._lazy()``). ``run(text)`` normalises ``text`` and walks it character
    by character, returning ``(normalised_text, phones, tones, langs)`` where
    the three lists are parallel:

    * a character with a g2pw reading contributes the units of its syllable,
      each with the syllable tone capped at 5 and lang id 0;
    * a run of ASCII letters/apostrophes starting at a letter is sent through
      the English g2p; phones known to ``F.SYM2ID`` are kept with their stress
      digit as the tone and lang id 1, then one ``'SP'`` (lang 1) follows;
    * any other character yields itself if it is in ``F.PUNCT``, or an ``'SP'``
      (lang 0) if it is whitespace and the previous phone is not already ``'SP'``.
    """
    import frontend_bopomofo as F
    import text_norm
    import re

    F._lazy()
    word_start = re.compile(r'[A-Za-z]')
    word_body = re.compile(r"[A-Za-z']")

    def run(text):
        text = text_norm.normalize(text)
        syllables = F._g2pw(text)[0]
        chars = list(text)
        n_chars = len(chars)
        phones, tones, langs = [], [], []

        def emit(symbol, tone, lang):
            # Keep the three parallel lists in lock-step.
            phones.append(symbol)
            tones.append(tone)
            langs.append(lang)

        pos = 0
        while pos < n_chars:
            ch = chars[pos]
            syllable = syllables[pos] if pos < len(syllables) else None

            # Character with a bopomofo reading: one entry per syllable unit.
            if syllable is not None:
                units, tone = F._split_syllable(syllable)
                for unit in units:
                    emit(unit, min(tone, 5), 0)
                pos += 1
                continue

            # English word: consume the whole letter/apostrophe run at once.
            if word_start.match(ch):
                end = pos
                while end < n_chars and word_body.match(chars[end]):
                    end += 1
                for raw in F._g2pen(''.join(chars[pos:end])):
                    arpa = raw.strip()
                    if not arpa:
                        continue
                    stress = 0
                    if arpa[-1].isdigit():
                        # Trailing digit is the stress mark; it becomes the tone.
                        stress = int(arpa[-1])
                        arpa = arpa[:-1]
                    if arpa in F.SYM2ID:
                        emit(arpa, stress, 1)
                emit('SP', 0, 1)
                pos = end
                continue

            # Everything else: punctuation, or whitespace collapsed to one SP.
            if ch in F.PUNCT:
                emit(ch, 0, 0)
            elif ch.strip() == '' and phones and phones[-1] != 'SP':
                emit('SP', 0, 0)
            pos += 1

        return text, phones, tones, langs

    return run, F


class Aligner:
    """Phone-level forced aligner on top of the espeak phoneme-CTC model.

    Attributes set by ``__init__``:
        vocab:  token -> id mapping read from the model repo's ``vocab.json``.
        blank:  id of the ``<pad>`` token, used as the CTC blank.
        model:  ``Wav2Vec2ForCTC`` in eval mode, moved to ``device``.
        device: the device the model was moved to.
    """

    def __init__(self, device):
        """Download/load the vocabulary and the CTC model onto ``device``."""
        from transformers import Wav2Vec2ForCTC
        from huggingface_hub import hf_hub_download
        # The vocab keys are IPA symbols: read explicitly as UTF-8 (the locale
        # default may not be) and close the handle deterministically.
        with open(hf_hub_download(ESPEAK, "vocab.json"), encoding="utf-8") as fh:
            self.vocab = json.load(fh)
        self.blank = self.vocab["<pad>"]
        self.model = Wav2Vec2ForCTC.from_pretrained(ESPEAK).to(device).eval()
        self.device = device

    def emissions(self, audio16k):
        """Return per-frame log-probabilities ``[T, V]`` on the CPU.

        ``audio16k`` is a NumPy waveform; it is cast to float32 and given a
        batch dimension before being fed to the model (``main`` passes mono
        audio resampled to 16 kHz).
        """
        iv = torch.from_numpy(audio16k).float().unsqueeze(0).to(self.device)
        with torch.inference_mode():
            logits = self.model(iv).logits[0]              # [T,V]
        return torch.log_softmax(logits, dim=-1).cpu()

    def durations(self, audio16k, phones, langs):
        """Return integer relative duration per phone (contiguous CTC spans).

        Args:
            audio16k: waveform handed to :meth:`emissions`.
            phones:   frontend phone symbols.
            langs:    parallel lang ids; 1 selects ``ARPA2IPA``, anything else
                      selects ``BOPO2IPA``.

        Returns:
            ``(durs, n_anchor)``: ``durs`` holds one integer >= 1 per phone and
            ``n_anchor`` is the number of phones that were alignment targets.
            When nothing can be aligned (no targets, more targets than frames,
            or the aligner raising) the result is all-ones with ``n_anchor`` 0.
        """
        emis = self.emissions(audio16k)
        T = emis.shape[0]
        # target = alignable phones mapped to a vocab token id
        tgt_ids, tgt_pi = [], []
        for pi, (p, lg) in enumerate(zip(phones, langs)):
            tok = ARPA2IPA.get(p) if lg == 1 else BOPO2IPA.get(p)
            if tok is not None and tok in self.vocab:
                tgt_ids.append(self.vocab[tok])
                tgt_pi.append(pi)

        uniform = ([1] * len(phones), 0)
        if not tgt_ids or len(tgt_ids) > T:
            return uniform

        # Imported lazily: only the real alignment path needs torchaudio.
        from torchaudio.functional import forced_align
        tokens = torch.tensor([tgt_ids], dtype=torch.int32)
        try:
            aligned, _ = forced_align(emis.unsqueeze(0), tokens, blank=self.blank)
        except Exception:
            # Deliberate best-effort: an utterance the aligner rejects still
            # gets a (uniform) row; the 0 anchor count marks it as unaligned.
            return uniform
        durs = self._spans_to_durations(
            aligned[0].tolist(), tgt_pi, phones, T, self.blank)
        return durs, len(tgt_pi)

    @staticmethod
    def _spans_to_durations(path, tgt_pi, phones, T, blank):
        """Convert a frame-level CTC path into one rounded duration per phone.

        ``path`` is the aligned token id per frame, ``tgt_pi`` the phone index
        of each target token in order, ``T`` the number of frames.
        """
        n = len(phones)
        dur = [0.0] * n
        # first frame of each emitted target token (in order)
        starts = []
        prev = blank
        for fi, tk in enumerate(path):
            if tk != blank and tk != prev:
                starts.append(fi)
            prev = tk
        L = min(len(starts), len(tgt_pi))
        # contiguous span per alignable phone: [start_k, start_{k+1})
        for k in range(L):
            end = starts[k + 1] if k + 1 < len(starts) else T
            dur[tgt_pi[k]] = max(1.0, float(end - starts[k]))
        # leading silence -> give to first alignable phone
        if L >= 1 and starts[0] > 0:
            dur[tgt_pi[0]] += starts[0]
        # non-alignable phones (SP / punct / OOV): small relative value,
        # scaled from the mean anchored duration (4.0 if nothing anchored)
        anchored = [d for d in dur if d > 0]
        base = (sum(anchored) / len(anchored)) if anchored else 4.0
        for pi in range(n):
            if dur[pi] == 0:
                dur[pi] = base * 0.5 if phones[pi] == 'SP' else base * 0.25
        return [max(1, int(round(x))) for x in dur]


def main():
    """Align every manifest row and write one JSON line per aligned utterance.

    Command-line options:
        --manifest  JSONL input; each row needs ``id``, ``text`` and an audio
                    path under ``wav`` or ``target_audio``.
        --out       JSONL output path (overwritten).
        --limit     process only the first N rows (0 = all).
        --device    torch device to use when CUDA is available; otherwise the
                    aligner runs on the CPU.

    A row that fails for any reason is reported on stdout and skipped, so one
    bad utterance does not abort the run.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--manifest", required=True)
    ap.add_argument("--out", required=True)
    ap.add_argument("--limit", type=int, default=0)
    ap.add_argument("--device", default="cuda")
    args = ap.parse_args()
    dev = args.device if torch.cuda.is_available() else "cpu"
    run, F = build_frontend_with_charidx()
    al = Aligner(dev)
    # The manifest carries non-ASCII text: read it as UTF-8 regardless of the
    # locale, close it promptly, and tolerate blank lines.
    with open(args.manifest, encoding="utf-8") as fh:
        rows = [json.loads(line) for line in fh if line.strip()]
    if args.limit:
        rows = rows[:args.limit]
    n_ok = 0
    # `with` guarantees the output is flushed and closed even if interrupted.
    with open(args.out, "w", encoding="utf-8") as out:
        for r in rows:
            wav = r.get("wav") or r.get("target_audio")
            try:
                a, sr = sf.read(wav, dtype="float32")
                if a.ndim > 1:
                    a = a.mean(1)  # average the channels down to mono
                a16 = librosa.resample(a, orig_sr=sr, target_sr=16000) if sr != 16000 else a
                _, phones, tones, langs = run(r["text"])
                durs, n_anchor = al.durations(a16, phones, langs)
                phone_ids = [F.SYM2ID.get(p, F.SYM2ID['UNK']) for p in phones]
                out.write(json.dumps({
                    "id": r["id"], "text": r["text"], "target_audio": wav,
                    "phone_ids": phone_ids, "tone_ids": tones, "lang_ids": langs,
                    "hifigan_durations": durs, "speaker_id": 0,
                }, ensure_ascii=False)+"\n")
                n_ok += 1
                # Show the first few rows as a quick sanity check.
                if n_ok <= 5:
                    print(f"  {r['id']} n_ph={len(phones)} anchored={n_anchor} "
                          f"dur_sum={sum(durs)} dur[:18]={durs[:18]}", flush=True)
            except Exception as e:
                print(f"  [skip {r.get('id')}] {type(e).__name__}: {e}", flush=True)
    print(f"DONE aligned {n_ok} rows -> {args.out}")


if __name__ == "__main__":
    main()