PrimeTTS / scripts /select_diverse_text.py
Luigi's picture
PrimeTTS: full training pipeline + weights (fine-tune of Inflect-Nano-v1)
a37967e verified
Raw
History Blame Contribute Delete
4.17 kB
#!/usr/bin/env python3
"""Select a DIVERSE training-text set to maximize phoneme/character coverage.
Why this matters: at 4.63M params the model is NOT capacity-limited — but it can only
pronounce characters/words it has SEEN. A narrow corpus (e.g. a few hundred Han chars)
leaves most held-out characters unseen -> garbled output. This script builds a broad,
coverage-maximizing text set from Tatoeba so the teacher (and then the student) cover
the common vocabulary.
zh-TW: Tatoeba `cmn` -> OpenCC s2twp (Taiwan traditional) -> greedy Han-CHAR coverage.
en:    Tatoeba `eng` -> greedy WORD coverage (English has few phones, so word/prosody variety matters).
Usage:
python select_diverse_text.py --lang zh --n 6000 --out expand_zh.tsv
python select_diverse_text.py --lang en --n 6000 --out expand_en.tsv
Then feed the .tsv (id<TAB>text) to gen_breezy_corpus.py to synthesize the teacher audio.
Deps: urllib (stdlib, for the download) and opencc (zh only). Tatoeba dumps are licensed CC-BY 2.0 FR.
"""
import argparse, bz2, os, re, random, urllib.request
# Per-language Tatoeba sentence export; `{lang}` is a Tatoeba language code such as "cmn" or "eng".
TATOEBA = "https://downloads.tatoeba.org/exports/per_language/{lang}/{lang}_sentences.tsv.bz2"


def HAN(s):
    """Return the set of distinct CJK Unified Ideographs (U+4E00..U+9FFF) in ``s``.

    Used both as a sentence filter and as the coverage unit for zh selection.
    Characters from the CJK extension blocks (e.g. Extension A, U+3400..U+4DBF)
    and compatibility ideographs are not counted. The upper-case name is kept so
    existing callers keep working.
    """
    return {c for c in s if "\u4e00" <= c <= "\u9fff"}
def download(lang, url=None):
    """Return the path of ``{lang}_sentences.tsv`` in the CWD, fetching it if absent.

    When the TSV does not exist yet, the bz2 archive is downloaded (from ``url``,
    or from the Tatoeba per-language export for ``lang`` when ``url`` is None) to
    ``{lang}_sentences.tsv.bz2`` and decompressed. Decompression goes to a
    ``.part`` file that is renamed into place only on success, so an interrupted
    or corrupt download never leaves a truncated TSV that a later run would
    mistake for a cached copy. The ``.bz2`` archive is kept on disk.

    Args:
        lang: Tatoeba language code (e.g. ``"cmn"``, ``"eng"``).
        url: optional archive URL override (anything ``urlretrieve`` accepts,
            e.g. a mirror or a local ``file://`` URI).

    Returns:
        The TSV file name, relative to the current working directory.

    Raises:
        Propagates ``urlretrieve`` errors (e.g. ``urllib.error.URLError``) and
        bz2 decompression errors (``OSError`` / ``EOFError`` on bad data).
    """
    f = f"{lang}_sentences.tsv"
    if os.path.exists(f):
        return f  # cached from a previous (successful) run
    if url is None:
        url = TATOEBA.format(lang=lang)
    print("downloading", url)
    archive = f + ".bz2"
    urllib.request.urlretrieve(url, archive)
    tmp = f + ".part"
    try:
        with bz2.open(archive, "rt", encoding="utf-8") as src, open(tmp, "w", encoding="utf-8") as dst:
            dst.writelines(src)
        os.replace(tmp, f)  # atomic publish: `f` exists only if fully decompressed
    finally:
        if os.path.exists(tmp):
            os.remove(tmp)  # drop partial output on failure
    return f
def select_zh(path, n, seed=42):
    """Pick up to ``n`` zh-TW sentences from a Tatoeba ``cmn`` dump, maximizing Han-char coverage.

    Each row's third tab-separated column is converted with OpenCC ``s2twp``
    (simplified -> Taiwan traditional, with phrase conversion), and ASCII
    ``, ! ? : ;`` are normalized to their full-width forms. A sentence is kept
    only if it contains 6-26 *distinct* Han characters and every non-Han
    character is one of the allowed full-width/CJK punctuation marks. Exact
    duplicates (after conversion) are dropped, then ``greedy_cover`` selects
    the final set using ``HAN`` as the coverage unit.

    Args:
        path: Tatoeba ``*_sentences.tsv`` file; rows with fewer than three
            tab-separated fields are skipped.
        n: number of sentences to return.
        seed: RNG seed forwarded to ``greedy_cover``.

    Returns:
        List of selected sentence strings.
    """
    import opencc
    cc = opencc.OpenCC("s2twp")  # simplified -> Taiwan traditional (with phrase conversion)
    # Chinese text normally uses full-width punctuation; map stray ASCII marks onto
    # those forms so both spellings pass the filter and the output is uniform.
    # Escapes keep the character widths unambiguous in the source.
    to_fullwidth = str.maketrans({
        ",": "\uff0c",  # full-width comma
        "!": "\uff01",  # full-width exclamation mark
        "?": "\uff1f",  # full-width question mark
        ":": "\uff1a",  # full-width colon
        ";": "\uff1b",  # full-width semicolon
    })
    # full-width , . ! ? enumeration-comma : ; and the ellipsis
    allowed = set("\uff0c\u3002\uff01\uff1f\u3001\uff1a\uff1b\u2026")
    seen_t, cands = set(), []
    with open(path, encoding="utf-8") as fh:
        for line in fh:
            p = line.rstrip("\n").split("\t")
            if len(p) < 3:
                continue
            t = cc.convert(p[2].strip()).translate(to_fullwidth)
            h = HAN(t)
            if not (6 <= len(h) <= 26):  # counts distinct Han chars, not total length
                continue
            if set(t) - h - allowed:  # some non-Han char is not allowed punctuation
                continue
            if t not in seen_t:
                seen_t.add(t)
                cands.append(t)
    return greedy_cover(cands, HAN, n, seed)
def select_en(path, n, seed=42):
    """Pick up to ``n`` English sentences from a Tatoeba ``eng`` dump, maximizing word coverage.

    The sentence is taken from each row's third tab-separated column. It is kept
    only if it consists solely of ASCII letters, digits, spaces and
    ``, . - ' ? !``, has 4-14 letter/apostrophe tokens, and no such token is
    longer than 15 characters. Exact duplicates are dropped, then
    ``greedy_cover`` selects the final set using distinct lower-cased words as
    the coverage unit.

    Args:
        path: Tatoeba ``*_sentences.tsv`` file; rows with fewer than three
            tab-separated fields are skipped.
        n: number of sentences to return.
        seed: RNG seed forwarded to ``greedy_cover``.

    Returns:
        List of selected sentence strings.
    """
    def words(s):
        # coverage unit: distinct lower-cased words, apostrophes kept ("don't")
        return set(re.findall(r"[a-z']+", s.lower()))

    seen_t, cands = set(), []
    with open(path, encoding="utf-8") as fh:
        for line in fh:
            p = line.rstrip("\n").split("\t")
            if len(p) < 3:
                continue
            t = p[2].strip()
            if not re.fullmatch(r"[A-Za-z0-9 ,.\-'?!]+", t):
                continue
            w = re.findall(r"[A-Za-z']+", t)
            if not (4 <= len(w) <= 14) or any(len(x) > 15 for x in w):
                continue
            if t not in seen_t:
                seen_t.add(t)
                cands.append(t)
    return greedy_cover(cands, words, n, seed)
def greedy_cover(cands, unit, n, seed):
    """Pick up to ``n`` texts: greedy max-coverage of ``unit(text)``, then random top-up.

    Phase 1 is the classic greedy max-coverage heuristic. It repeatedly takes
    the candidate that adds the most not-yet-covered units, until ``n`` texts
    are chosen or no candidate adds anything new. The implementation is lazy:
    a heap stores each candidate's last known gain, which is an upper bound
    because gains only shrink as coverage grows. A popped candidate is
    re-scored before it is accepted. Ties are broken by a seeded shuffle.

    Phase 2 fills any remaining budget with a seeded random sample of the
    unpicked candidates, adding repeated units for frequency.

    Args:
        cands: candidate texts; the list is not modified.
        unit: callable mapping a text to a ``set`` of coverage units.
        n: maximum number of texts to return (``n <= 0`` returns ``[]``).
        seed: seed for a private ``random.Random``; the global RNG is untouched.

    Returns:
        List of at most ``n`` texts: greedy picks first (in pick order), then top-up.
    """
    import heapq

    rng = random.Random(seed)
    pool = list(cands)
    rng.shuffle(pool)  # randomize tie-breaking among equal-gain candidates
    # (negated gain upper bound, index): the min-heap pops the largest gain first
    heap = [(-len(unit(t)), i) for i, t in enumerate(pool)]
    heapq.heapify(heap)
    covered, picked = set(), []
    while heap and len(picked) < n:
        neg_bound, i = heapq.heappop(heap)
        new_units = unit(pool[i]) - covered
        if not new_units:
            continue  # adds nothing new; left for the top-up phase
        if len(new_units) == -neg_bound:
            # fresh gain equals its bound, which is >= every other bound -> greedy choice
            picked.append(i)
            covered |= new_units
        else:
            heapq.heappush(heap, (-len(new_units), i))  # stale bound: re-queue with fresh gain
    selected = [pool[i] for i in picked]
    taken = set(picked)
    rest = [t for i, t in enumerate(pool) if i not in taken]
    rng.shuffle(rest)
    selected += rest[: max(0, n - len(selected))]
    print(f"selected {len(selected)} sentences | unique units covered: {len(covered)}")
    return selected
def main():
    """CLI entry point: fetch the Tatoeba dump, select diverse sentences, write an ``id<TAB>text`` TSV."""
    ap = argparse.ArgumentParser(
        description="Select a coverage-maximizing training-text set from Tatoeba.")
    ap.add_argument("--lang", choices=["zh", "en"], required=True)
    ap.add_argument("--n", type=int, default=6000, help="number of sentences to select")
    ap.add_argument("--out", required=True, help="output TSV path (id<TAB>text)")
    ap.add_argument("--seed", type=int, default=42,
                    help="RNG seed for tie-breaking and random top-up")
    args = ap.parse_args()
    if args.n < 1:
        ap.error("--n must be >= 1")
    if args.lang == "zh":
        sents = select_zh(download("cmn"), args.n, seed=args.seed)
    else:
        sents = select_en(download("eng"), args.n, seed=args.seed)
    with open(args.out, "w", encoding="utf-8") as o:
        for i, t in enumerate(sents):
            # ids look like "zhe00000" / "ene00000": lang + literal "e" + 5-digit index
            o.write(f"{args.lang}e{i:05d}\t{t}\n")
    print("wrote", args.out, len(sents))


if __name__ == "__main__":
    main()