PrimeTTS / scripts /align_durations_v4.py
Luigi's picture
fix: aligner uses text_norm.normalize (entity normalizer)
f8f98a9 verified
Raw
History Blame Contribute Delete
8.11 kB
#!/usr/bin/env python3
"""v4 aligner: PHONE-LEVEL forced alignment via espeak phoneme-CTC.
Root-cause fix: v1/v3 ran CHAR-level CTC and then heuristically SPLIT each char span
across its phones -> wrong RELATIVE phone durations -> over-smoothed acoustics.
v4 aligns OUR frontend's phone sequence DIRECTLY:
each phone -> one IPA token in facebook/wav2vec2-lv-60-espeak-cv-ft's vocab
-> torchaudio.forced_align(emissions, our_phone_ipa_targets) -> per-phone frames.
No splitting. The spans of the alignable phones are contiguous (together they cover
all frames); fit_durations() rescales to the true mel length at train time, so only
RELATIVE proportions matter -- and those now come straight from the phone recognizer.
Non-alignable phones (SP / punctuation / OOV) are not CTC targets; they get a small
synthetic relative value instead.
"""
from __future__ import annotations

import argparse
import json
import os
import sys

import numpy as np
import torch
import soundfile as sf
import librosa
# Directory that presumably holds the project frontend modules imported by
# build_frontend_with_charidx() (frontend_bopomofo, text_norm). Defaults to the
# original developer path; set $ZHTW8K_FRONTEND_DIR to point at another checkout.
FRONTEND_DIR = os.environ.get(
    "ZHTW8K_FRONTEND_DIR", "/home/luigi/jetson-tts/mossnano/zhtw8k"
)
sys.path.insert(0, FRONTEND_DIR)

# Hugging Face model id of the espeak phoneme-CTC recognizer used by Aligner.
ESPEAK = "facebook/wav2vec2-lv-60-espeak-cv-ft"

# ARPABET (en, stress stripped) -> single IPA token present in the espeak vocab.
ARPA2IPA = {
    'AA':'ɑ','AE':'æ','AH':'ʌ','AO':'ɔ','AW':'aʊ','AY':'aɪ','B':'b','CH':'tʃ',
    'D':'d','DH':'ð','EH':'ɛ','ER':'ɚ','EY':'eɪ','F':'f','G':'ɡ','HH':'h',
    'IH':'ɪ','IY':'i','JH':'dʒ','K':'k','L':'l','M':'m','N':'n','NG':'ŋ',
    'OW':'oʊ','OY':'ɔɪ','P':'p','R':'ɹ','S':'s','SH':'ʃ','T':'t','TH':'θ',
    'UH':'ʊ','UW':'u','V':'v','W':'w','Y':'j','Z':'z','ZH':'ʒ',
}

# bopomofo (zh) -> toneless IPA token in the espeak Mandarin inventory (best-effort).
# A token missing from the model vocab is not an error: Aligner.durations()
# simply treats that phone as non-alignable.
# NOTE(review): 'ㄑ'->'tɕh' and 'ㄘ'->'tsh' spell aspiration with a plain 'h'
# while the other aspirates use 'ʰ' ('pʰ', 'tʰ', 'kʰ', 'tʃʰ'); 'ㄓ'/'ㄔ' map to
# post-alveolar 'tʃ'/'tʃʰ' although 'ㄕ'/'ㄖ' are retroflex 'ʂ'/'ʐ'. Confirm
# these spellings against the model's vocab.json.
BOPO2IPA = {
    'ㄅ':'p','ㄆ':'pʰ','ㄇ':'m','ㄈ':'f','ㄉ':'t','ㄊ':'tʰ','ㄋ':'n','ㄌ':'l',
    'ㄍ':'k','ㄎ':'kʰ','ㄏ':'x','ㄐ':'tɕ','ㄑ':'tɕh','ㄒ':'ɕ','ㄓ':'tʃ','ㄔ':'tʃʰ',
    'ㄕ':'ʂ','ㄖ':'ʐ','ㄗ':'ts','ㄘ':'tsh','ㄙ':'s',
    'ㄧ':'i','ㄨ':'u','ㄩ':'y','ㄚ':'a','ㄛ':'o','ㄜ':'ɤ','ㄝ':'e','ㄞ':'ai',
    'ㄟ':'ei','ㄠ':'au','ㄡ':'ou','ㄢ':'a','ㄣ':'ə','ㄤ':'ɑ','ㄥ':'ə','ㄦ':'ɚ',
    'ㄭ':'ɨ',
}
def build_frontend_with_charidx():
    """Initialise the zh-TW frontend and return ``(run, F)``.

    ``F`` is the ``frontend_bopomofo`` module (after ``F._lazy()``).
    ``run(text)`` normalises *text* with ``text_norm.normalize`` and returns
    ``(normalized_text, phones, tones, langs)``. The last three are parallel
    lists with one entry per emitted phone:

    * a character with a bopomofo reading -> the syllable's units, tone capped
      at 5, lang 0;
    * a run of ASCII letters/apostrophes -> English G2P phones (a trailing
      stress digit goes into the tone slot, 0 if absent; phones not in
      ``F.SYM2ID`` are dropped), lang 1, then one ``'SP'`` with lang 1;
    * a character in ``F.PUNCT`` -> itself, tone 0, lang 0;
    * whitespace -> ``'SP'`` (lang 0) unless the previous phone is already
      ``'SP'`` or nothing has been emitted yet; any other character is dropped.
    """
    import frontend_bopomofo as F
    import text_norm
    import re
    F._lazy()
    word_head = re.compile(r'[A-Za-z]')
    word_body = re.compile(r"[A-Za-z']")

    def run(text):
        text = text_norm.normalize(text)
        # NOTE(review): assumes _g2pw(...)[0] holds one reading per character of
        # `text` (None where there is none); the index walk below relies on it.
        readings = F._g2pw(text)[0]
        chars = list(text)
        seq = []  # (phone, tone, lang) triples, unzipped at the end
        pos = 0
        while pos < len(chars):
            reading = readings[pos] if pos < len(readings) else None
            ch = chars[pos]
            if reading is not None:
                units, tone = F._split_syllable(reading)
                seq.extend((unit, min(tone, 5), 0) for unit in units)
                pos += 1
                continue
            if word_head.match(ch):
                stop = pos
                while stop < len(chars) and word_body.match(chars[stop]):
                    stop += 1
                for raw in F._g2pen(''.join(chars[pos:stop])):
                    ph = raw.strip()
                    if not ph:
                        continue
                    stress = 0
                    if ph[-1].isdigit():
                        stress = int(ph[-1])
                        ph = ph[:-1]
                    if ph in F.SYM2ID:
                        seq.append((ph, stress, 1))
                # One word-boundary pause after every English run.
                seq.append(('SP', 0, 1))
                pos = stop
                continue
            if ch in F.PUNCT:
                seq.append((ch, 0, 0))
            elif not ch.strip() and seq and seq[-1][0] != 'SP':
                seq.append(('SP', 0, 0))
            pos += 1
        phones = [ph for ph, _, _ in seq]
        tones = [tn for _, tn, _ in seq]
        langs = [lg for _, _, lg in seq]
        return text, phones, tones, langs

    return run, F
class Aligner:
    """Phone-level CTC forced aligner built on the ``ESPEAK`` wav2vec2 phoneme model.

    The model emits IPA phoneme tokens. Frontend phones are mapped onto that
    vocabulary via ``ARPA2IPA`` (lang 1) or ``BOPO2IPA`` (any other lang) and
    aligned with ``torchaudio.functional.forced_align``.
    """

    def __init__(self, device):
        """Download/load the model and its vocab, and move the model to *device*."""
        from transformers import Wav2Vec2ForCTC
        from huggingface_hub import hf_hub_download
        # vocab.json maps token string -> CTC output index. It contains IPA, so
        # read it explicitly as UTF-8 rather than with the locale encoding.
        with open(hf_hub_download(ESPEAK, "vocab.json"), encoding="utf-8") as fh:
            self.vocab = json.load(fh)
        # The <pad> token is used as the CTC blank.
        self.blank = self.vocab["<pad>"]
        self.model = Wav2Vec2ForCTC.from_pretrained(ESPEAK).to(device).eval()
        self.device = device

    def emissions(self, audio16k, normalize=True):
        """Return frame-level CTC log-probabilities ``[T, V]`` (on CPU).

        Parameters
        ----------
        audio16k : numpy.ndarray
            1-D waveform, assumed to be 16 kHz (callers resample beforehand).
        normalize : bool
            Apply zero-mean / unit-variance normalization first, as the model's
            Wav2Vec2 feature extractor does for this checkpoint. Pass ``False``
            to feed the raw waveform (the historical behaviour).
        """
        x = torch.from_numpy(audio16k).float()
        if normalize:
            # Same formula as Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm:
            # (x - mean) / sqrt(population_var + 1e-7).
            centered = x - x.mean()
            x = centered / torch.sqrt(centered.pow(2).mean() + 1e-7)
        iv = x.unsqueeze(0).to(self.device)
        with torch.inference_mode():
            logits = self.model(iv).logits[0]  # [T, V]
        return torch.log_softmax(logits, dim=-1).cpu()

    def _target_ids(self, phones, langs):
        """Return ``(token_ids, phone_indices)`` for the phones that map to a vocab token."""
        tgt_ids, tgt_pi = [], []
        for pi, (p, lg) in enumerate(zip(phones, langs)):
            table = ARPA2IPA if lg == 1 else BOPO2IPA
            tok = table.get(p)
            if tok is not None and tok in self.vocab:
                tgt_ids.append(self.vocab[tok])
                tgt_pi.append(pi)
        return tgt_ids, tgt_pi

    @staticmethod
    def _token_starts(path, blank):
        """Return the first frame index of each emitted token in a CTC frame path.

        A token starts at a non-blank frame whose id differs from the previous
        frame's. A blank between two equal ids therefore yields two starts.
        """
        starts = []
        prev = blank
        for fi, tk in enumerate(path):
            if tk != blank and tk != prev:
                starts.append(fi)
            prev = tk
        return starts

    def durations(self, audio16k, phones, langs):
        """Return ``(durations, n_anchor)`` for *phones* aligned against *audio16k*.

        ``durations`` holds one positive int per phone. For anchored phones the
        value is relative, in emission frames. ``n_anchor`` is the number of
        phones that received a CTC span (0 when the uniform fallback is used).

        Only phones whose mapped IPA token exists in the model vocab are CTC
        targets. Each anchored phone spans from its first emitted frame to the
        next anchored phone's first frame. The last anchored phone runs to the
        end, and the first one also absorbs the leading frames, so the anchored
        spans cover all T frames. Any other phone gets ``0.5 * mean`` (``'SP'``)
        or ``0.25 * mean`` of the anchored durations (mean = 4.0 if there are
        none). If nothing is alignable, there are more targets than frames, or
        forced alignment fails, every phone gets 1.
        """
        emis = self.emissions(audio16k)
        T = emis.shape[0]
        n = len(phones)
        tgt_ids, tgt_pi = self._target_ids(phones, langs)
        if not tgt_ids or len(tgt_ids) > T:
            return [1] * n, 0
        # Imported lazily so the uniform fallback above works without torchaudio.
        from torchaudio.functional import forced_align
        tokens = torch.tensor([tgt_ids], dtype=torch.int32)
        try:
            aligned, _ = forced_align(emis.unsqueeze(0), tokens, blank=self.blank)
        except Exception:
            # Deliberate best-effort: any alignment failure (e.g. too few frames
            # once repeated tokens need separating blanks) -> uniform durations.
            return [1] * n, 0
        starts = self._token_starts(aligned[0].tolist(), self.blank)
        n_anchor = min(len(starts), len(tgt_pi))
        dur = [0.0] * n
        # Contiguous span per anchored phone: [start_k, start_{k+1}). Frames of
        # non-alignable phones in between fall into the preceding anchored span.
        for k in range(n_anchor):
            end = starts[k + 1] if k + 1 < len(starts) else T
            dur[tgt_pi[k]] = max(1.0, float(end - starts[k]))
        # Leading silence goes to the first anchored phone.
        if n_anchor and starts[0] > 0:
            dur[tgt_pi[0]] += starts[0]
        # Non-alignable phones (SP / punct / OOV): small relative value.
        anchored = [d for d in dur if d > 0]
        base = sum(anchored) / len(anchored) if anchored else 4.0
        dur = [
            d if d > 0 else base * (0.5 if p == 'SP' else 0.25)
            for d, p in zip(dur, phones)
        ]
        return [max(1, int(round(x))) for x in dur], n_anchor
def main():
    """CLI entry point: align a JSONL manifest and write a JSONL training manifest.

    Each input row needs ``id``, ``text`` and an audio path under ``wav`` (or
    ``target_audio``). For every row that aligns, one JSON object is written to
    ``--out`` with the frontend's phone/tone/lang ids and the aligner's
    per-phone ``hifigan_durations`` (``speaker_id`` is always 0). Rows that
    raise are reported and skipped; the run is deliberately best-effort.
    """
    ap = argparse.ArgumentParser(description="v4 phone-level duration aligner")
    ap.add_argument("--manifest", required=True, help="input JSONL manifest")
    ap.add_argument("--out", required=True, help="output JSONL path")
    ap.add_argument("--limit", type=int, default=0,
                    help="only process the first N rows (0 = all)")
    ap.add_argument("--device", default="cuda",
                    help="torch device; cpu is used when CUDA is unavailable")
    args = ap.parse_args()

    # Read the manifest before loading any model so a bad path fails fast.
    # Explicit UTF-8: rows carry non-ASCII (e.g. Chinese) text, and the output
    # is written as UTF-8 too. Blank lines are skipped instead of crashing.
    with open(args.manifest, encoding="utf-8") as fh:
        rows = [json.loads(line) for line in fh if line.strip()]
    if args.limit:
        rows = rows[:args.limit]

    dev = args.device if torch.cuda.is_available() else "cpu"
    run, F = build_frontend_with_charidx()
    al = Aligner(dev)

    n_ok = 0
    with open(args.out, "w", encoding="utf-8") as out:
        for r in rows:
            wav = r.get("wav") or r.get("target_audio")
            try:
                a, sr = sf.read(wav, dtype="float32")
                if a.ndim > 1:
                    a = a.mean(1)  # downmix multi-channel audio to mono
                a16 = librosa.resample(a, orig_sr=sr, target_sr=16000) if sr != 16000 else a
                _, phones, tones, langs = run(r["text"])
                durs, n_anchor = al.durations(a16, phones, langs)
                phone_ids = [F.SYM2ID.get(p, F.SYM2ID['UNK']) for p in phones]
                out.write(json.dumps({
                    "id": r["id"], "text": r["text"], "target_audio": wav,
                    "phone_ids": phone_ids, "tone_ids": tones, "lang_ids": langs,
                    "hifigan_durations": durs, "speaker_id": 0,
                }, ensure_ascii=False) + "\n")
                n_ok += 1
                if n_ok <= 5:  # print the first few rows as a sanity check
                    print(f" {r['id']} n_ph={len(phones)} anchored={n_anchor} "
                          f"dur_sum={sum(durs)} dur[:18]={durs[:18]}", flush=True)
            except Exception as e:
                print(f" [skip {r.get('id')}] {type(e).__name__}: {e}", flush=True)
    print(f"DONE aligned {n_ok} rows -> {args.out}")


if __name__ == "__main__":
    main()