| |
| """v4 aligner: PHONE-LEVEL forced alignment via espeak phoneme-CTC. |
| |
| Root-cause fix. v1/v3 ran CHAR-level CTC and then heuristically SPLIT each char span |
| across its phones -> wrong RELATIVE phone durations -> over-smoothed acoustics. |
| v4 aligns OUR frontend's phone sequence DIRECTLY: |
| each phone -> one IPA token in facebook/wav2vec2-lv-60-espeak-cv-ft's vocab |
| -> torchaudio.forced_align(emissions, our_phone_ipa_targets) -> per-phone frames. |
| No splitting. Durations are contiguous (cover all frames); fit_durations() rescales |
| to true mel length at train time, so only RELATIVE proportions matter -- and those |
| are now correct from the phone recognizer. |
| |
| Non-alignable phones (SP / punctuation / OOV) get a fraction of the mean aligned-phone duration (0.5x for SP, 0.25x otherwise). |
| """ |
| from __future__ import annotations |
| import argparse, json, sys |
| import numpy as np, torch, soundfile as sf, librosa |
| sys.path.insert(0, "/home/luigi/jetson-tts/mossnano/zhtw8k") |
|
|
# Hugging Face repo id of the espeak-phoneme CTC model; the aligner loads both
# its vocab.json and its weights from this id.
ESPEAK = "facebook/wav2vec2-lv-60-espeak-cv-ft"
|
|
| |
# ARPAbet symbol -> IPA token string. The aligner uses this table for phones
# whose language id is 1 and then looks the IPA string up in the recognizer
# vocabulary; symbols missing here are treated as non-alignable.
ARPA2IPA = {
    "AA": "ɑ", "AE": "æ", "AH": "ʌ", "AO": "ɔ",
    "AW": "aʊ", "AY": "aɪ", "B": "b", "CH": "tʃ",
    "D": "d", "DH": "ð", "EH": "ɛ", "ER": "ɚ",
    "EY": "eɪ", "F": "f", "G": "ɡ", "HH": "h",
    "IH": "ɪ", "IY": "i", "JH": "dʒ", "K": "k",
    "L": "l", "M": "m", "N": "n", "NG": "ŋ",
    "OW": "oʊ", "OY": "ɔɪ", "P": "p", "R": "ɹ",
    "S": "s", "SH": "ʃ", "T": "t", "TH": "θ",
    "UH": "ʊ", "UW": "u", "V": "v", "W": "w",
    "Y": "j", "Z": "z", "ZH": "ʒ",
}
| |
# Bopomofo (zhuyin) unit -> IPA token string. The aligner uses this table for
# phones whose language id is not 1 and then looks the IPA string up in the
# recognizer vocabulary; anything not found there is treated as non-alignable.
#
# NOTE(review): aspiration is spelled two ways below -- with the modifier
# letter (pʰ, tʰ, kʰ, tʃʰ) and with ASCII "h" (tɕh, tsh). Only spellings that
# literally exist in the model's vocab.json get aligned; the others silently
# fall back to the non-alignable share. Verify each value against vocab.json.
BOPO2IPA = {
    # initials
    "ㄅ": "p", "ㄆ": "pʰ", "ㄇ": "m", "ㄈ": "f",
    "ㄉ": "t", "ㄊ": "tʰ", "ㄋ": "n", "ㄌ": "l",
    "ㄍ": "k", "ㄎ": "kʰ", "ㄏ": "x",
    "ㄐ": "tɕ", "ㄑ": "tɕh", "ㄒ": "ɕ",
    "ㄓ": "tʃ", "ㄔ": "tʃʰ", "ㄕ": "ʂ", "ㄖ": "ʐ",
    "ㄗ": "ts", "ㄘ": "tsh", "ㄙ": "s",
    # medials and finals
    "ㄧ": "i", "ㄨ": "u", "ㄩ": "y",
    "ㄚ": "a", "ㄛ": "o", "ㄜ": "ɤ", "ㄝ": "e",
    "ㄞ": "ai", "ㄟ": "ei", "ㄠ": "au", "ㄡ": "ou",
    "ㄢ": "a", "ㄣ": "ə", "ㄤ": "ɑ", "ㄥ": "ə", "ㄦ": "ɚ",
    "ㄭ": "ɨ",
}
|
|
|
|
def build_frontend_with_charidx():
    """Initialise the bopomofo frontend and return ``(run, F)``.

    ``F`` is the imported ``frontend_bopomofo`` module. ``run(text)`` normalises
    ``text`` and walks it character by character, returning
    ``(normalised_text, phones, tones, langs)`` where the three lists are
    parallel (one entry per emitted phone):

    * characters with a g2pW reading are split into units by
      ``F._split_syllable`` and emitted with the syllable tone (capped at 5)
      and language id 0;
    * a run of ASCII letters/apostrophes is sent to ``F._g2pen``; each returned
      phone has a trailing stress digit moved into the tone slot, is kept only
      if it is in ``F.SYM2ID``, and gets language id 1; an ``'SP'`` marker
      (language id 1) follows every such word;
    * characters in ``F.PUNCT`` are emitted as-is; whitespace becomes a single
      ``'SP'`` unless nothing has been emitted yet or the previous phone is
      already ``'SP'``; every other character is dropped.
    """
    import re

    import frontend_bopomofo as F
    import text_norm

    F._lazy()

    first_letter = re.compile(r'[A-Za-z]')
    word_char = re.compile(r"[A-Za-z']")

    def run(text):
        text = text_norm.normalize(text)
        # presumably one reading per character, None where g2pW has no
        # bopomofo for it -- verify against frontend_bopomofo
        bopo = F._g2pw(text)[0]
        chars = list(text)
        n_chars = len(chars)
        phones, tones, langs = [], [], []

        def emit(symbol, tone, lang):
            phones.append(symbol)
            tones.append(tone)
            langs.append(lang)

        pos = 0
        while pos < n_chars:
            ch = chars[pos]
            reading = bopo[pos] if pos < len(bopo) else None

            # Chinese character with a bopomofo reading.
            if reading is not None:
                units, tone = F._split_syllable(reading)
                for unit in units:
                    emit(unit, min(tone, 5), 0)
                pos += 1
                continue

            # English word: consume the whole letter/apostrophe run at once.
            if first_letter.match(ch):
                end = pos
                while end < n_chars and word_char.match(chars[end]):
                    end += 1
                for raw in F._g2pen(''.join(chars[pos:end])):
                    symbol = raw.strip()
                    if not symbol:
                        continue
                    stress = 0
                    if symbol[-1].isdigit():
                        stress = int(symbol[-1])
                        symbol = symbol[:-1]
                    if symbol in F.SYM2ID:
                        emit(symbol, stress, 1)
                emit('SP', 0, 1)
                pos = end
                continue

            # Punctuation, whitespace, or an unsupported character.
            if ch in F.PUNCT:
                emit(ch, 0, 0)
            elif ch.strip() == '' and phones and phones[-1] != 'SP':
                emit('SP', 0, 0)
            pos += 1

        return text, phones, tones, langs

    return run, F
|
|
|
|
class Aligner:
    """Phone-level forced aligner built on the espeak phoneme-CTC model.

    Attributes:
        vocab: token string -> id mapping loaded from the model's vocab.json.
        blank: id of the ``<pad>`` token, used as the CTC blank.
        model: the ``Wav2Vec2ForCTC`` network in eval mode on ``device``.
        device: torch device string the model lives on.
    """

    def __init__(self, device):
        """Download/load the vocabulary and the CTC model onto ``device``."""
        from transformers import Wav2Vec2ForCTC
        from huggingface_hub import hf_hub_download

        vocab_path = hf_hub_download(ESPEAK, "vocab.json")
        # The vocabulary is full of IPA characters: decode explicitly as UTF-8
        # instead of the platform default, and close the handle when done.
        with open(vocab_path, encoding="utf-8") as fh:
            self.vocab = json.load(fh)
        self.blank = self.vocab["<pad>"]
        self.model = Wav2Vec2ForCTC.from_pretrained(ESPEAK).to(device).eval()
        self.device = device

    @staticmethod
    def _uniform(n):
        """Fallback durations: one relative unit for each of ``n`` phones."""
        return [1] * n

    def emissions(self, audio16k):
        """Return per-frame log-probabilities as a CPU tensor ``(frames, vocab)``.

        Args:
            audio16k: 1-D numpy waveform; the caller resamples it to 16 kHz.
        """
        wave = np.asarray(audio16k, dtype=np.float32)
        if wave.size:
            # Wav2Vec2 checkpoints expect zero-mean / unit-variance input; the
            # HF feature extractor normally applies exactly this formula, and
            # this class bypasses the extractor by feeding the model directly.
            wave = ((wave - wave.mean()) / np.sqrt(wave.var() + 1e-7)).astype(np.float32)
        iv = torch.from_numpy(wave).float().unsqueeze(0).to(self.device)
        with torch.inference_mode():
            logits = self.model(iv).logits[0]
        return torch.log_softmax(logits, dim=-1).cpu()

    def durations(self, audio16k, phones, langs):
        """Return integer relative duration per phone (contiguous CTC spans).

        Args:
            audio16k: 1-D numpy waveform at 16 kHz.
            phones: frontend phone symbols.
            langs: language id per phone; 1 selects ``ARPA2IPA``, anything
                else selects ``BOPO2IPA``.

        Returns:
            ``(durs, n_anchor)``. ``durs`` has one int >= 1 per phone.
            ``n_anchor`` is the number of phones that were given to the
            aligner, or 0 when the uniform fallback (all ones) was used
            because nothing was alignable, there were more targets than
            frames, or ``forced_align`` raised.
        """
        emis = self.emissions(audio16k)
        n_frames = emis.shape[0]

        # Keep only phones that map to a token present in the model vocab,
        # remembering which phone index each target belongs to.
        tgt_ids, tgt_pi = [], []
        for pi, (p, lg) in enumerate(zip(phones, langs)):
            tok = ARPA2IPA.get(p) if lg == 1 else BOPO2IPA.get(p)
            if tok is not None and tok in self.vocab:
                tgt_ids.append(self.vocab[tok])
                tgt_pi.append(pi)

        n = len(phones)
        if not tgt_ids or len(tgt_ids) > n_frames:
            return self._uniform(n), 0

        # Imported only once an alignment is really attempted; kept outside
        # the try so a missing torchaudio is reported, not swallowed.
        from torchaudio.functional import forced_align

        tokens = torch.tensor([tgt_ids], dtype=torch.int32)
        try:
            aligned, _ = forced_align(emis.unsqueeze(0), tokens, blank=self.blank)
        except Exception as exc:
            # Best-effort: keep the utterance with uniform durations, but make
            # the degradation visible instead of hiding it.
            print(f" [align-fallback] {type(exc).__name__}: {exc}",
                  file=sys.stderr, flush=True)
            return self._uniform(n), 0
        path = aligned[0].tolist()

        # A segment starts at every non-blank frame whose token differs from
        # the previous frame (blanks reset `prev`, so a repeated token that is
        # separated by a blank opens a new segment).
        starts = []
        prev = self.blank
        for fi, tk in enumerate(path):
            if tk != self.blank and tk != prev:
                starts.append(fi)
            prev = tk

        dur = [0.0] * n
        n_seg = min(len(starts), len(tgt_pi))
        # Each segment runs up to the next segment's start (or the last frame),
        # so blanks that follow a token are counted into that token.
        for k in range(n_seg):
            begin = starts[k]
            end = starts[k + 1] if k + 1 < len(starts) else n_frames
            dur[tgt_pi[k]] = max(1.0, float(end - begin))

        # Frames before the first token are given to the first aligned phone.
        if n_seg >= 1 and starts[0] > 0:
            dur[tgt_pi[0]] += starts[0]

        # Phones that received no span get a share of the mean aligned length.
        anchored = [d for d in dur if d > 0]
        base = sum(anchored) / len(anchored) if anchored else 4.0
        for pi, p in enumerate(phones):
            if dur[pi] == 0:
                dur[pi] = base * 0.5 if p == 'SP' else base * 0.25
        return [max(1, int(round(x))) for x in dur], len(tgt_pi)
|
|
|
|
def main(argv=None):
    """CLI entry point: align every manifest row and write a JSONL file.

    Reads ``--manifest`` (JSON lines with ``id``, ``text`` and ``wav`` or
    ``target_audio``), computes per-phone relative durations for each row and
    writes one JSON object per successfully processed row to ``--out``. Rows
    that raise are reported on stdout and skipped.

    Args:
        argv: optional argument list; ``None`` means ``sys.argv[1:]``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--manifest", required=True)
    ap.add_argument("--out", required=True)
    ap.add_argument("--limit", type=int, default=0)
    ap.add_argument("--device", default="cuda")
    args = ap.parse_args(argv)

    # Fall back to CPU only when a CUDA device was requested but is not
    # available; an explicit non-CUDA device is passed through unchanged.
    dev = args.device
    if dev.startswith("cuda") and not torch.cuda.is_available():
        dev = "cpu"

    run, F = build_frontend_with_charidx()
    al = Aligner(dev)

    # The manifest carries CJK text: read it as UTF-8 (the output is written
    # as UTF-8 too), close it promptly, and tolerate blank lines.
    with open(args.manifest, encoding="utf-8") as fh:
        rows = [json.loads(line) for line in fh if line.strip()]
    if args.limit:
        rows = rows[:args.limit]

    n_ok = 0
    with open(args.out, "w", encoding="utf-8") as out:
        for r in rows:
            wav = r.get("wav") or r.get("target_audio")
            try:
                a, sr = sf.read(wav, dtype="float32")
                if a.ndim > 1:
                    a = a.mean(1)  # down-mix multi-channel audio to mono
                a16 = librosa.resample(a, orig_sr=sr, target_sr=16000) if sr != 16000 else a
                _, phones, tones, langs = run(r["text"])
                durs, n_anchor = al.durations(a16, phones, langs)
                phone_ids = [F.SYM2ID.get(p, F.SYM2ID['UNK']) for p in phones]
                out.write(json.dumps({
                    "id": r["id"], "text": r["text"], "target_audio": wav,
                    "phone_ids": phone_ids, "tone_ids": tones, "lang_ids": langs,
                    "hifigan_durations": durs, "speaker_id": 0,
                }, ensure_ascii=False) + "\n")
                n_ok += 1
                # Show the first few rows as a sanity check.
                if n_ok <= 5:
                    print(f" {r['id']} n_ph={len(phones)} anchored={n_anchor} "
                          f"dur_sum={sum(durs)} dur[:18]={durs[:18]}", flush=True)
            except Exception as e:
                # Per-row best effort: one bad row must not abort the run.
                print(f" [skip {r.get('id')}] {type(e).__name__}: {e}", flush=True)
    print(f"DONE aligned {n_ok} rows -> {args.out}")
|
|
|
|
# Run the alignment CLI only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|