#!/usr/bin/env python3
"""v4 aligner: PHONE-LEVEL forced alignment via espeak phoneme-CTC.

Root-cause fix. v1/v3 did CHAR-level CTC then heuristically SPLIT each char
span across its phones -> wrong RELATIVE phone durations -> over-smoothed
acoustic.

v4 aligns OUR frontend's phone sequence DIRECTLY: each phone -> one IPA token
in facebook/wav2vec2-lv-60-espeak-cv-ft's vocab ->
torchaudio.forced_align(emissions, our_phone_ipa_targets) -> per-phone frames.
No splitting.

Durations are contiguous (cover all frames); fit_durations() rescales to true
mel length at train time, so only RELATIVE proportions matter -- and those are
now correct from the phone recognizer. Non-alignable phones (SP / punctuation /
OOV) carry a small relative value.
"""
from __future__ import annotations

import argparse
import json
import re
import sys

import numpy as np
import torch

# NOTE: soundfile / librosa are imported inside main() (their only user), the
# same way transformers / huggingface_hub / torchaudio are imported lazily
# below, so this module can be imported without the audio I/O stack.

sys.path.insert(0, "/home/luigi/jetson-tts/mossnano/zhtw8k")

ESPEAK = "facebook/wav2vec2-lv-60-espeak-cv-ft"

# ARPABET (en, stress stripped) -> single IPA token present in the espeak vocab.
ARPA2IPA = {
    'AA':'ɑ','AE':'æ','AH':'ʌ','AO':'ɔ','AW':'aʊ','AY':'aɪ','B':'b','CH':'tʃ',
    'D':'d','DH':'ð','EH':'ɛ','ER':'ɚ','EY':'eɪ','F':'f','G':'ɡ','HH':'h',
    'IH':'ɪ','IY':'i','JH':'dʒ','K':'k','L':'l','M':'m','N':'n','NG':'ŋ',
    'OW':'oʊ','OY':'ɔɪ','P':'p','R':'ɹ','S':'s','SH':'ʃ','T':'t','TH':'θ',
    'UH':'ʊ','UW':'u','V':'v','W':'w','Y':'j','Z':'z','ZH':'ʒ',
}

# bopomofo (zh) -> toneless IPA token in the espeak Mandarin inventory
# (best-effort). Tokens missing from the model vocab are simply treated as
# non-alignable by Aligner.durations().
BOPO2IPA = {
    'ㄅ':'p','ㄆ':'pʰ','ㄇ':'m','ㄈ':'f','ㄉ':'t','ㄊ':'tʰ','ㄋ':'n','ㄌ':'l',
    'ㄍ':'k','ㄎ':'kʰ','ㄏ':'x','ㄐ':'tɕ','ㄑ':'tɕh','ㄒ':'ɕ','ㄓ':'tʃ','ㄔ':'tʃʰ',
    'ㄕ':'ʂ','ㄖ':'ʐ','ㄗ':'ts','ㄘ':'tsh','ㄙ':'s',
    'ㄧ':'i','ㄨ':'u','ㄩ':'y','ㄚ':'a','ㄛ':'o','ㄜ':'ɤ','ㄝ':'e','ㄞ':'ai',
    'ㄟ':'ei','ㄠ':'au','ㄡ':'ou','ㄢ':'a','ㄣ':'ə','ㄤ':'ɑ','ㄥ':'ə','ㄦ':'ɚ',
    'ㄭ':'ɨ',
}

# Compiled once instead of on every character of every utterance.
_LATIN_START = re.compile(r'[A-Za-z]')
_LATIN_WORD = re.compile(r"[A-Za-z']")


def build_frontend_with_charidx():
    """Build the text frontend used for alignment.

    Returns:
        ``(run, F)`` where ``F`` is the imported ``frontend_bopomofo`` module
        and ``run(text)`` returns ``(normalized_text, phones, tones, langs)``;
        the three lists are parallel (one entry per emitted phone symbol) and
        ``langs`` is 0 for the bopomofo/punctuation/space path and 1 for the
        English path.
    """
    import frontend_bopomofo as F
    import text_norm

    F._lazy()

    def run(text):
        text = text_norm.normalize(text)
        # Indexed per character below; a None entry sends the character to the
        # English / punctuation / whitespace branches.
        bopo = F._g2pw(text)[0]
        chars = list(text)
        phones, tones, langs = [], [], []

        def emit(sym, tone, lang):
            phones.append(sym)
            tones.append(tone)
            langs.append(lang)

        i = 0
        while i < len(chars):
            syllable = bopo[i] if i < len(bopo) else None
            ch = chars[i]
            if syllable is not None:
                units, tone = F._split_syllable(syllable)
                for unit in units:
                    emit(unit, min(tone, 5), 0)
                i += 1
            elif _LATIN_START.match(ch):
                # Consume the whole run of letters/apostrophes as one word.
                j = i
                while j < len(chars) and _LATIN_WORD.match(chars[j]):
                    j += 1
                for p in F._g2pen(''.join(chars[i:j])):
                    p = p.strip()
                    if not p:
                        continue
                    # A trailing digit is split off and stored as the tone.
                    if p[-1].isdigit():
                        stress = int(p[-1])
                        p = p[:-1]
                    else:
                        stress = 0
                    if p in F.SYM2ID:
                        emit(p, stress, 1)
                # NOTE(review): one 'SP' per English word (after the phone
                # loop) -- reconstructed from whitespace-collapsed source;
                # confirm against the original layout.
                emit('SP', 0, 1)
                i = j
            else:
                if ch in F.PUNCT:
                    emit(ch, 0, 0)
                elif ch.strip() == '' and phones and phones[-1] != 'SP':
                    # Whitespace becomes a single SP; never two in a row.
                    emit('SP', 0, 0)
                i += 1
        return text, phones, tones, langs

    return run, F


class Aligner:
    """Phone-level forced aligner on top of the espeak phoneme-CTC model."""

    def __init__(self, device):
        """Load the model vocabulary and weights onto ``device``."""
        from transformers import Wav2Vec2ForCTC
        from huggingface_hub import hf_hub_download

        # vocab.json holds IPA symbols: read it as UTF-8 regardless of locale,
        # and close the handle once parsed.
        with open(hf_hub_download(ESPEAK, "vocab.json"), encoding="utf-8") as fh:
            self.vocab = json.load(fh)
        self.blank = self._blank_id(self.vocab)
        self.model = Wav2Vec2ForCTC.from_pretrained(ESPEAK).to(device).eval()
        self.device = device

    @staticmethod
    def _blank_id(vocab):
        """Return the CTC blank id: the ``<pad>`` entry of ``vocab``.

        Falls back to an empty-string key, which is what earlier revisions of
        this script looked up. Raises ``KeyError`` if neither is present.
        """
        for key in ("<pad>", ""):
            if key in vocab:
                return vocab[key]
        raise KeyError("no CTC blank token ('<pad>') in vocab")

    @staticmethod
    def _uniform_durations(n):
        """Fallback when alignment is impossible: one unit per phone."""
        return [1] * n

    @staticmethod
    def _token_starts(path, blank):
        """Return the first frame index of each emitted target token.

        A new token starts wherever the frame label is non-blank and differs
        from the previous frame's label.
        """
        starts = []
        prev = blank
        for frame, tok in enumerate(path):
            if tok != blank and tok != prev:
                starts.append(frame)
            prev = tok
        return starts

    @staticmethod
    def _durations_from_path(path, blank, n_frames, phones, tgt_pi):
        """Convert a frame-level alignment path into integer phone durations.

        Args:
            path: per-frame token ids from the forced alignment.
            blank: CTC blank id.
            n_frames: number of emission frames (end of the last span).
            phones: full phone sequence, including non-alignable symbols.
            tgt_pi: index into ``phones`` of each aligned target, in order.

        Returns:
            A list of ``len(phones)`` ints, each >= 1.
        """
        starts = Aligner._token_starts(path, blank)
        dur = [0.0] * len(phones)
        n_spans = min(len(starts), len(tgt_pi))
        # Contiguous span per alignable phone: [start_k, start_{k+1}).
        for k in range(n_spans):
            end = starts[k + 1] if k + 1 < len(starts) else n_frames
            dur[tgt_pi[k]] = max(1.0, float(end - starts[k]))
        # Leading silence -> give to the first alignable phone.
        if n_spans >= 1 and starts[0] > 0:
            dur[tgt_pi[0]] += starts[0]
        # Non-alignable phones (SP / punct / OOV): small value relative to the
        # mean anchored duration. Anchored entries are >= 1, so 0 == unset.
        anchored = [d for d in dur if d > 0]
        base = (sum(anchored) / len(anchored)) if anchored else 4.0
        for pi, phone in enumerate(phones):
            if dur[pi] == 0:
                dur[pi] = base * 0.5 if phone == 'SP' else base * 0.25
        return [max(1, int(round(x))) for x in dur]

    def _alignable_targets(self, phones, langs):
        """Map phones to vocab ids, keeping only those the model can align.

        Returns ``(token_ids, phone_indices)``: parallel lists with the vocab
        id and the position in ``phones`` of every alignable phone.
        """
        tgt_ids, tgt_pi = [], []
        for pi, (phone, lang) in enumerate(zip(phones, langs)):
            table = ARPA2IPA if lang == 1 else BOPO2IPA
            tok = table.get(phone)
            if tok is not None and tok in self.vocab:
                tgt_ids.append(self.vocab[tok])
                tgt_pi.append(pi)
        return tgt_ids, tgt_pi

    def emissions(self, audio16k):
        """Return per-frame log-probabilities ``[T, V]`` on CPU."""
        iv = torch.from_numpy(audio16k).float().unsqueeze(0).to(self.device)
        with torch.inference_mode():
            logits = self.model(iv).logits[0]  # [T,V]
        return torch.log_softmax(logits, dim=-1).cpu()

    def durations(self, audio16k, phones, langs):
        """Return integer relative duration per phone (contiguous CTC spans).

        Args:
            audio16k: mono float waveform (numpy array) fed to the model.
            phones: phone symbols from the frontend.
            langs: per-phone language id (1 -> ARPA2IPA, else BOPO2IPA).

        Returns:
            ``(durations, n_anchored)``. ``n_anchored`` is the number of phones
            that took part in the alignment; it is 0 when the uniform fallback
            was used (nothing alignable, more targets than frames, or the
            aligner raised).
        """
        from torchaudio.functional import forced_align

        emis = self.emissions(audio16k)
        n_frames = emis.shape[0]
        tgt_ids, tgt_pi = self._alignable_targets(phones, langs)
        if not tgt_ids or len(tgt_ids) > n_frames:
            return self._uniform_durations(len(phones)), 0

        tokens = torch.tensor([tgt_ids], dtype=torch.int32)
        try:
            aligned, _ = forced_align(emis.unsqueeze(0), tokens, blank=self.blank)
        except Exception:
            # Deliberate best-effort: an utterance the aligner rejects still
            # gets (uniform) durations; n_anchored == 0 flags it to the caller.
            return self._uniform_durations(len(phones)), 0

        durs = self._durations_from_path(
            aligned[0].tolist(), self.blank, n_frames, phones, tgt_pi)
        return durs, len(tgt_pi)


def _pick_device(requested):
    """Fall back to CPU only when a CUDA device was requested but is absent."""
    if requested.startswith("cuda") and not torch.cuda.is_available():
        return "cpu"
    return requested


def _read_manifest(path, limit=0):
    """Load a JSON-lines manifest (UTF-8), skipping blank lines.

    A non-zero ``limit`` slices the loaded rows as ``rows[:limit]``.
    """
    with open(path, encoding="utf-8") as fh:
        rows = [json.loads(line) for line in fh if line.strip()]
    return rows[:limit] if limit else rows


def main():
    """CLI entry point: align every manifest row and write a JSONL file."""
    import librosa
    import soundfile as sf

    ap = argparse.ArgumentParser()
    ap.add_argument("--manifest", required=True)
    ap.add_argument("--out", required=True)
    ap.add_argument("--limit", type=int, default=0)
    ap.add_argument("--device", default="cuda")
    args = ap.parse_args()

    dev = _pick_device(args.device)
    run, F = build_frontend_with_charidx()
    al = Aligner(dev)
    rows = _read_manifest(args.manifest, args.limit)

    n_ok = 0
    with open(args.out, "w", encoding="utf-8") as out:
        for r in rows:
            wav = r.get("wav") or r.get("target_audio")
            try:
                a, sr = sf.read(wav, dtype="float32")
                if a.ndim > 1:
                    a = a.mean(1)  # downmix to mono
                a16 = librosa.resample(a, orig_sr=sr, target_sr=16000) if sr != 16000 else a
                _, phones, tones, langs = run(r["text"])
                durs, n_anchor = al.durations(a16, phones, langs)
                phone_ids = [F.SYM2ID.get(p, F.SYM2ID['UNK']) for p in phones]
                out.write(json.dumps({
                    "id": r["id"], "text": r["text"], "target_audio": wav,
                    "phone_ids": phone_ids, "tone_ids": tones, "lang_ids": langs,
                    "hifigan_durations": durs, "speaker_id": 0,
                }, ensure_ascii=False)+"\n")
                n_ok += 1
                # Show the first few rows as a sanity check.
                if n_ok <= 5:
                    print(f" {r['id']} n_ph={len(phones)} anchored={n_anchor} "
                          f"dur_sum={sum(durs)} dur[:18]={durs[:18]}", flush=True)
            except Exception as e:
                # Per-row boundary: report and continue with the next row.
                print(f" [skip {r.get('id')}] {type(e).__name__}: {e}", flush=True)
    print(f"DONE aligned {n_ok} rows -> {args.out}")


if __name__ == "__main__":
    main()