#!/usr/bin/env python3
"""v4 aligner: PHONE-LEVEL forced alignment via espeak phoneme-CTC.
Root-cause fix. v1/v3 did CHAR-level CTC then heuristically SPLIT each char span
across its phones -> wrong RELATIVE phone durations -> over-smoothed acoustic.
v4 aligns OUR frontend's phone sequence DIRECTLY:
each phone -> one IPA token in facebook/wav2vec2-lv-60-espeak-cv-ft's vocab
-> torchaudio.forced_align(emissions, our_phone_ipa_targets) -> per-phone frames.
No splitting. Durations are contiguous (cover all frames); fit_durations() rescales
to true mel length at train time, so only RELATIVE proportions matter -- and those
are now correct from the phone recognizer.
Non-alignable phones (SP / punctuation / OOV) carry a small relative value.
"""
from __future__ import annotations
import argparse, json, sys
import numpy as np, torch, soundfile as sf, librosa
sys.path.insert(0, "/home/luigi/jetson-tts/mossnano/zhtw8k")
# Hugging Face repo id of the phoneme-CTC checkpoint. Aligner loads both the
# model weights and vocab.json from this repo.
ESPEAK = "facebook/wav2vec2-lv-60-espeak-cv-ft"
# ARPABET (en, stress stripped) -> single IPA token present in the espeak vocab.
# Looked up in Aligner.durations for phones whose lang id is 1. A phone whose
# mapped token is missing from the loaded vocab is treated as non-alignable there.
ARPA2IPA = {
'AA':'ɑ','AE':'æ','AH':'ʌ','AO':'ɔ','AW':'aʊ','AY':'aɪ','B':'b','CH':'tʃ',
'D':'d','DH':'ð','EH':'ɛ','ER':'ɚ','EY':'eɪ','F':'f','G':'ɡ','HH':'h',
'IH':'ɪ','IY':'i','JH':'dʒ','K':'k','L':'l','M':'m','N':'n','NG':'ŋ',
'OW':'oʊ','OY':'ɔɪ','P':'p','R':'ɹ','S':'s','SH':'ʃ','T':'t','TH':'θ',
'UH':'ʊ','UW':'u','V':'v','W':'w','Y':'j','Z':'z','ZH':'ʒ',
}
# bopomofo (zh) -> toneless IPA token in the espeak Mandarin inventory (best-effort).
# Looked up in Aligner.durations for every phone whose lang id is not 1, so SP and
# punctuation simply miss here and fall through to the non-alignable path.
# NOTE(review): aspiration is spelled with 'ʰ' for some entries (pʰ, tʰ, kʰ, tʃʰ)
# and with a plain 'h' for others (tɕh, tsh). Any spelling that is absent from
# vocab.json silently makes that phone non-alignable -- confirm each one.
# NOTE(review): ㄢ/ㄣ/ㄤ/ㄥ map to a bare vowel token; presumably the nasal coda is
# dropped on purpose so that each phone stays a single CTC token -- confirm.
BOPO2IPA = {
'ㄅ':'p','ㄆ':'pʰ','ㄇ':'m','ㄈ':'f','ㄉ':'t','ㄊ':'tʰ','ㄋ':'n','ㄌ':'l',
'ㄍ':'k','ㄎ':'kʰ','ㄏ':'x','ㄐ':'tɕ','ㄑ':'tɕh','ㄒ':'ɕ','ㄓ':'tʃ','ㄔ':'tʃʰ',
'ㄕ':'ʂ','ㄖ':'ʐ','ㄗ':'ts','ㄘ':'tsh','ㄙ':'s',
'ㄧ':'i','ㄨ':'u','ㄩ':'y','ㄚ':'a','ㄛ':'o','ㄜ':'ɤ','ㄝ':'e','ㄞ':'ai',
'ㄟ':'ei','ㄠ':'au','ㄡ':'ou','ㄢ':'a','ㄣ':'ə','ㄤ':'ɑ','ㄥ':'ə','ㄦ':'ɚ',
'ㄭ':'ɨ',
}
def build_frontend_with_charidx():
    """Build the text -> phone-sequence frontend used by the aligner.

    Imports the project modules ``frontend_bopomofo`` and ``text_norm``, calls
    ``F._lazy()`` once, and returns ``(run, F)`` where ``F`` is the imported
    ``frontend_bopomofo`` module and ``run(text)`` returns
    ``(normalized_text, phones, tones, langs)`` as three parallel lists.

    ``langs`` is 0 for symbols produced by the bopomofo / punctuation path and
    1 for symbols produced by the English path.
    """
    import frontend_bopomofo as F
    import text_norm
    import re

    F._lazy()
    letter_re = re.compile(r'[A-Za-z]')
    word_char_re = re.compile(r"[A-Za-z']")

    def run(text):
        """Normalize *text* and convert it to parallel phone/tone/lang lists."""
        text = text_norm.normalize(text)
        # Indexed by character position below; a None entry (or an index past
        # the end) means "no bopomofo reading for this character".
        readings = F._g2pw(text)[0]
        chars = list(text)
        total = len(chars)
        phones, tones, langs = [], [], []

        def emit(symbol, tone, lang):
            # Keep the three output lists in lockstep.
            phones.append(symbol)
            tones.append(tone)
            langs.append(lang)

        pos = 0
        while pos < total:
            reading = readings[pos] if pos < len(readings) else None
            ch = chars[pos]

            # Character with a bopomofo reading: one entry per split unit,
            # all sharing the syllable tone (capped at 5).
            if reading is not None:
                units, tone = F._split_syllable(reading)
                for unit in units:
                    emit(unit, min(tone, 5), 0)
                pos += 1
                continue

            # Latin letter: consume the whole run of letters/apostrophes and
            # send it through the English G2P as one word.
            if letter_re.match(ch):
                end = pos
                while end < total and word_char_re.match(chars[end]):
                    end += 1
                for raw in F._g2pen(''.join(chars[pos:end])):
                    sym = raw.strip()
                    if not sym:
                        continue
                    # A trailing digit is the stress mark; strip it and keep
                    # it as the "tone" value.
                    if sym[-1].isdigit():
                        stress = int(sym[-1])
                        sym = sym[:-1]
                    else:
                        stress = 0
                    if sym in F.SYM2ID:
                        emit(sym, stress, 1)
                # One SP after every English word.
                emit('SP', 0, 1)
                pos = end
                continue

            # Anything else: keep known punctuation, collapse whitespace into
            # a single SP (never leading, never doubled), drop the rest.
            if ch in F.PUNCT:
                emit(ch, 0, 0)
            elif ch.strip() == '' and phones and phones[-1] != 'SP':
                emit('SP', 0, 0)
            pos += 1

        return text, phones, tones, langs

    return run, F
class Aligner:
    """Phone-level forced aligner on top of the espeak phoneme-CTC wav2vec2 model.

    Attributes set by ``__init__``:
        vocab:  token -> id mapping loaded from the checkpoint's ``vocab.json``.
        blank:  id of the ``<pad>`` token, passed to ``forced_align`` as the blank.
        model:  ``Wav2Vec2ForCTC`` moved to ``device`` and put in eval mode.
        device: the device the model was moved to.
    """

    def __init__(self, device):
        """Download (or reuse the cached) vocab and model for ``ESPEAK``."""
        from transformers import Wav2Vec2ForCTC
        from huggingface_hub import hf_hub_download
        # vocab.json holds IPA tokens: read it as UTF-8 explicitly and close
        # the handle instead of leaking it.
        with open(hf_hub_download(ESPEAK, "vocab.json"), encoding="utf-8") as fh:
            self.vocab = json.load(fh)
        self.blank = self.vocab["<pad>"]
        self.model = Wav2Vec2ForCTC.from_pretrained(ESPEAK).to(device).eval()
        self.device = device

    def emissions(self, audio16k):
        """Return per-frame log-probabilities ``[T, V]`` on CPU for one utterance.

        ``audio16k`` is a NumPy array (it goes through ``torch.from_numpy``);
        the name suggests mono 16 kHz samples -- the caller in this file
        resamples to 16 kHz before calling.
        """
        iv = torch.from_numpy(audio16k).float().unsqueeze(0).to(self.device)
        with torch.inference_mode():
            logits = self.model(iv).logits[0]  # [T, V]
        return torch.log_softmax(logits, dim=-1).cpu()

    @staticmethod
    def _uniform(n):
        """Fallback result: duration 1 for each of ``n`` phones, 0 anchored."""
        return [1] * n, 0

    @staticmethod
    def _token_starts(path, blank):
        """Return the first frame index of each emitted token in a CTC path.

        A new token starts wherever a non-blank frame differs from the frame
        before it, so a repeated token separated by a blank counts twice.
        """
        starts = []
        prev = blank
        for frame, tok in enumerate(path):
            if tok != blank and tok != prev:
                starts.append(frame)
            prev = tok
        return starts

    def durations(self, audio16k, phones, langs):
        """Return integer relative duration per phone (contiguous CTC spans).

        Args:
            audio16k: waveform passed straight to ``emissions``.
            phones:   frontend phone symbols.
            langs:    per-phone language ids; 1 selects ``ARPA2IPA``, anything
                      else selects ``BOPO2IPA``.

        Returns:
            ``(durs, n_anchored)``: ``durs`` has one int >= 1 per phone and
            ``n_anchored`` is the number of phones that were sent to the
            aligner. When nothing can be aligned (no mappable phone, more
            targets than frames, or ``forced_align`` raises) every phone gets
            duration 1 and ``n_anchored`` is 0.
        """
        emis = self.emissions(audio16k)
        n_frames = emis.shape[0]
        n = len(phones)

        # target = alignable phones mapped to a vocab token id
        tgt_ids, tgt_pi = [], []
        for pi, (p, lg) in enumerate(zip(phones, langs)):
            tok = ARPA2IPA.get(p) if lg == 1 else BOPO2IPA.get(p)
            if tok is not None and tok in self.vocab:
                tgt_ids.append(self.vocab[tok])
                tgt_pi.append(pi)

        if not tgt_ids or len(tgt_ids) > n_frames:
            return self._uniform(n)

        # Imported only once an alignment is really attempted, so the uniform
        # fallback above does not depend on torchaudio. Kept outside the try
        # so a missing torchaudio is reported rather than hidden as a fallback.
        from torchaudio.functional import forced_align
        tokens = torch.tensor([tgt_ids], dtype=torch.int32)
        try:
            aligned, _ = forced_align(emis.unsqueeze(0), tokens, blank=self.blank)
        except Exception:
            # Deliberate best-effort: an un-alignable utterance still yields a
            # row; the returned anchor count of 0 marks it as a fallback.
            return self._uniform(n)

        starts = self._token_starts(aligned[0].tolist(), self.blank)
        dur = [0.0] * n
        n_spans = min(len(starts), len(tgt_pi))
        # contiguous span per alignable phone: [start_k, start_{k+1});
        # the last span runs to the final frame.
        for k in range(n_spans):
            end = starts[k + 1] if k + 1 < len(starts) else n_frames
            dur[tgt_pi[k]] = max(1.0, float(end - starts[k]))
        # leading silence -> give to first alignable phone
        if n_spans >= 1 and starts[0] > 0:
            dur[tgt_pi[0]] += starts[0]

        # non-alignable phones (SP / punct / OOV): a fraction of the mean
        # anchored duration -- half for SP, a quarter for everything else.
        anchored = [d for d in dur if d > 0]
        base = (sum(anchored) / len(anchored)) if anchored else 4.0
        for pi, p in enumerate(phones):
            if dur[pi] == 0:
                dur[pi] = base * 0.5 if p == 'SP' else base * 0.25
        return [max(1, int(round(x))) for x in dur], len(tgt_pi)
def main():
    """Align every manifest row and write one JSON line per aligned utterance.

    Command-line arguments:
        --manifest  JSONL input; each row needs ``id``, ``text`` and either
                    ``wav`` or ``target_audio`` (path to the audio file).
        --out       JSONL output path (overwritten).
        --limit     if non-zero, only the first N manifest rows are processed.
        --device    device to use when CUDA is available; otherwise "cpu" is
                    used regardless of this value.

    A row that raises anywhere in load / frontend / alignment is reported on
    stdout and skipped; the remaining rows are still processed.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--manifest", required=True)
    ap.add_argument("--out", required=True)
    ap.add_argument("--limit", type=int, default=0)
    ap.add_argument("--device", default="cuda")
    args = ap.parse_args()

    dev = args.device if torch.cuda.is_available() else "cpu"
    run, F = build_frontend_with_charidx()
    al = Aligner(dev)

    # Read as UTF-8 (the output is written as UTF-8 too), close the handle,
    # and tolerate blank lines such as a trailing newline at end of file.
    with open(args.manifest, encoding="utf-8") as manifest:
        rows = [json.loads(line) for line in manifest if line.strip()]
    if args.limit:
        rows = rows[:args.limit]

    n_ok = 0
    # `with` guarantees the output is flushed and closed even if the loop is
    # interrupted (e.g. KeyboardInterrupt, which `except Exception` lets through).
    with open(args.out, "w", encoding="utf-8") as out:
        for r in rows:
            wav = r.get("wav") or r.get("target_audio")
            try:
                a, sr = sf.read(wav, dtype="float32")
                if a.ndim > 1:
                    a = a.mean(1)  # down-mix multi-channel audio to mono
                a16 = librosa.resample(a, orig_sr=sr, target_sr=16000) if sr != 16000 else a
                _, phones, tones, langs = run(r["text"])
                durs, n_anchor = al.durations(a16, phones, langs)
                phone_ids = [F.SYM2ID.get(p, F.SYM2ID['UNK']) for p in phones]
                out.write(json.dumps({
                    "id": r["id"], "text": r["text"], "target_audio": wav,
                    "phone_ids": phone_ids, "tone_ids": tones, "lang_ids": langs,
                    "hifigan_durations": durs, "speaker_id": 0,
                }, ensure_ascii=False)+"\n")
                n_ok += 1
                # Show the first few alignments as a quick sanity check.
                if n_ok <= 5:
                    print(f" {r['id']} n_ph={len(phones)} anchored={n_anchor} "
                          f"dur_sum={sum(durs)} dur[:18]={durs[:18]}", flush=True)
            except Exception as e:
                # Per-row boundary: report and move on so one bad file does
                # not abort the whole batch.
                print(f" [skip {r.get('id')}] {type(e).__name__}: {e}", flush=True)
    print(f"DONE aligned {n_ok} rows -> {args.out}")


if __name__ == "__main__":
    main()
|