PrimeTTS / scripts /asr_filter.py
Luigi's picture
reproduction: actual VoxCPM2-TW pipeline scripts + master run + eval set
9033a1b verified
Raw
History Blame Contribute Delete
3.5 kB
#!/usr/bin/env python3
"""ASR quality-gate the VoxCPM2 teacher corpus before alignment/training.
zh + mix -> Han-only CER via Breeze-ASR-25 (zh-TW); en -> WER via generic whisper.
(mix uses Han-only CER, so embedded English is ignored — we only verify the Chinese portion.)
Drops clips above threshold. Writes <out>.clean.jsonl (kept) + <out>.rejected.jsonl (with scores).
Run in moss-nano-venv (faster_whisper, opencc). GPU recommended.
CUDA_VISIBLE_DEVICES=0 python asr_filter.py --manifest voxcpm_tw_manifest.jsonl --out voxcpm_tw_manifest
"""
import argparse, json, re, sys
import numpy as np, soundfile as sf
import opencc
# Shared OpenCC converter built from the "t2s" (Traditional -> Simplified Chinese) profile;
# _norm_zh runs both reference and hypothesis text through it before scoring.
_t2s = opencc.OpenCC("t2s")
def _han(s): return re.sub(r"[^一-鿿]", "", s)
def _norm_zh(s):
    """Normalize Chinese text for CER scoring.

    The text is converted with the module-level ``t2s`` OpenCC converter and
    then reduced to Han characters only via ``_han``. A falsy input (``None``
    or ``""``) is treated as the empty string.
    """
    text = s if s else ""
    converted = _t2s.convert(text)
    return _han(converted)
def _norm_en(s): return re.sub(r"[^a-z' ]", " ", (s or "").lower()).split()
def _lev(a, b):
m, n = len(a), len(b)
if m == 0: return n
prev = list(range(n + 1))
for i in range(1, m + 1):
cur = [i] + [0] * n
for j in range(1, n + 1):
cur[j] = min(prev[j] + 1, cur[j-1] + 1, prev[j-1] + (a[i-1] != b[j-1]))
prev = cur
return prev[n]
def _cer(ref, hyp):
    """Character error rate of *hyp* against *ref* over normalized Han characters.

    Both strings go through ``_norm_zh``; the edit distance is divided by the
    reference length, with the divisor floored at 1 so an empty reference does
    not divide by zero.
    """
    ref_chars = list(_norm_zh(ref))
    hyp_chars = list(_norm_zh(hyp))
    denominator = max(1, len(ref_chars))
    return _lev(ref_chars, hyp_chars) / denominator
def _wer(ref, hyp):
    """Word error rate of *hyp* against *ref* over ``_norm_en`` tokens.

    The edit distance between the two token lists is divided by the number of
    reference tokens, with the divisor floored at 1 so an empty reference does
    not divide by zero.
    """
    ref_words = _norm_en(ref)
    hyp_words = _norm_en(hyp)
    denominator = max(1, len(ref_words))
    return _lev(ref_words, hyp_words) / denominator
def main():
    """Score every manifest row with ASR and split it into kept / rejected JSONL files.

    Command-line options:
        --manifest      input JSONL; each row must have ``target_audio``, ``text`` and ``lang``.
        --out           output prefix; writes ``<out>.clean.jsonl`` and ``<out>.rejected.jsonl``.
        --device        device passed to ``WhisperModel`` (default ``cuda``).
        --compute-type  compute type passed to ``WhisperModel`` (default ``float16``).
        --cer-zh        max CER for rows whose ``lang`` is neither ``en`` nor ``mix`` (default 0.12).
        --cer-mix       max CER for ``lang == "mix"`` rows (default 0.15).
        --wer-en        max WER for ``lang == "en"`` rows (default 0.20).

    Rows with ``lang == "en"`` are transcribed by the generic model and scored
    with ``_wer``; all other rows are transcribed by the Breeze model with
    ``language="zh"`` and scored with ``_cer``. A row is kept when its score is
    <= the threshold; the rounded score is stored on the row under ``"wer"`` or
    ``"cer"``. If transcription or scoring raises, the row is rejected with the
    first 80 characters of the error stored under ``"_err"``.

    Raises:
        KeyError: if a manifest row lacks ``target_audio``, ``text`` or ``lang``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--manifest", required=True)
    ap.add_argument("--out", required=True, help="prefix -> <out>.clean.jsonl / <out>.rejected.jsonl")
    ap.add_argument("--device", default="cuda")
    ap.add_argument("--compute-type", default="float16")
    ap.add_argument("--cer-zh", type=float, default=0.12)
    ap.add_argument("--cer-mix", type=float, default=0.15)
    ap.add_argument("--wer-en", type=float, default=0.20)
    a = ap.parse_args()

    # Imported after argument parsing so --help and argument errors exit
    # before faster_whisper is loaded.
    from faster_whisper import WhisperModel
    breeze = WhisperModel("SoybeanMilk/faster-whisper-Breeze-ASR-25", device=a.device, compute_type=a.compute_type)
    generic = WhisperModel("medium", device=a.device, compute_type=a.compute_type)

    def asr(model, wav, lang):
        """Transcribe *wav* with *model* and return the concatenated segment text."""
        segs, _ = model.transcribe(wav, language=lang, beam_size=1)
        return "".join(s.text for s in segs)

    # Explicit UTF-8: the manifest carries Chinese text and the outputs are
    # written as UTF-8, so do not depend on the platform default encoding.
    with open(a.manifest, encoding="utf-8") as src:
        rows = [json.loads(line) for line in src if line.strip()]

    nk = nr = 0
    # Context managers guarantee both outputs are flushed and closed even if
    # the loop raises part-way through a long run.
    with open(f"{a.out}.clean.jsonl", "w", encoding="utf-8") as kept, \
            open(f"{a.out}.rejected.jsonl", "w", encoding="utf-8") as rej:
        # n counts every processed row, including rows rejected on error, so
        # the progress line and drop percentage share one denominator.
        for n, r in enumerate(rows, start=1):
            wav, ref, lang = r["target_audio"], r["text"], r["lang"]
            try:
                if lang == "en":
                    score = _wer(ref, asr(generic, wav, "en"))
                    thr = a.wer_en
                    metric = "wer"
                else:
                    score = _cer(ref, asr(breeze, wav, "zh"))
                    thr = a.cer_mix if lang == "mix" else a.cer_zh
                    metric = "cer"
            except Exception as e:
                # Best-effort per row: one unreadable clip must not abort the run.
                r["_err"] = str(e)[:80]
                passed = False
            else:
                r[metric] = round(float(score), 3)
                passed = score <= thr

            record = json.dumps(r, ensure_ascii=False) + "\n"
            if passed:
                kept.write(record)
                nk += 1
            else:
                rej.write(record)
                nr += 1

            if n % 200 == 0:
                kept.flush()
                rej.flush()
                print(f"{n}/{len(rows)} kept={nk} rej={nr} ({nr/n*100:.1f}% drop)", flush=True)

    print(f"FILTER_DONE kept={nk} rejected={nr} drop={nr/max(1,nk+nr)*100:.1f}%", flush=True)
# Run the filter only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()