#!/usr/bin/env python3
"""Select a DIVERSE training-text set to maximize phoneme/character coverage.

Why this matters: at 4.63M params the model is NOT capacity-limited — but it can only
pronounce characters/words it has SEEN. A narrow corpus (e.g. a few hundred Han chars)
leaves most held-out characters unseen -> garbled output. This script builds a broad,
coverage-maximizing text set from Tatoeba so the teacher (and then the student) cover
the common vocabulary.

zh-TW: Tatoeba `cmn` -> OpenCC s2twp (Taiwan traditional) -> greedy Han-CHAR coverage.
en   : Tatoeba `eng` -> greedy WORD coverage (English phones are few; word/prosody variety matters).

Usage:
  python select_diverse_text.py --lang zh --n 6000 --out expand_zh.tsv
  python select_diverse_text.py --lang en --n 6000 --out expand_en.tsv
Then feed the .tsv (id<TAB>text) to gen_breezy_corpus.py to synthesize the teacher audio.

Deps: urllib (stdlib, download), opencc (zh only). Tatoeba dumps are CC-BY 2.0 FR.
"""
import argparse, bz2, os, re, random, urllib.request

TATOEBA = "https://downloads.tatoeba.org/exports/per_language/{lang}/{lang}_sentences.tsv.bz2"
HAN = lambda s: set(c for c in s if "一" <= c <= "鿿")
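

# Sketch (hypothetical helper, not called by main()): quantify the failure mode described
# in the module docstring by measuring what fraction of the distinct units in a held-out
# text set never occur in the training texts. `unit` is whichever coverage unit is in use,
# e.g. HAN for zh-TW or a lowercase word splitter for en. The held-out set is an input you
# supply; nothing in this script produces one.
def unseen_fraction(train_texts, heldout_texts, unit):
    seen = set()
    for t in train_texts:
        seen |= unit(t)
    held = set()
    for t in heldout_texts:
        held |= unit(t)
    if not held:
        return 0.0  # nothing to cover
    # Share of distinct held-out units the model would never have heard.
    return len(held - seen) / len(held)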


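def download(lang):
    """Download and decompress the Tatoeba per-language dump (once); return the .tsv path."""
    f = f"{lang}_sentences.tsv"
    if not os.path.exists(f):
        url = TATOEBA.format(lang=lang)
        print("downloading", url)
        urllib.request.urlretrieve(url, f + ".bz2")
        # Decompress to a temporary name, then rename, so an interrupted run never
        # leaves a truncated .tsv that later runs would silently reuse.
        tmp = f + ".tmp"
        with bz2.open(f + ".bz2", "rt", encoding="utf-8") as i, open(tmp, "w", encoding="utf-8") as o:
            for line in i:
                o.write(line)
        os.replace(tmp, f)
    return f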


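def select_zh(path, n, seed=42):
    import opencc  # only needed for zh
    cc = opencc.OpenCC("s2twp")  # simplified -> Taiwan traditional (with phrase conversion)
    allowed = set("，。！？、：；…")  # full-width punctuation kept alongside Han chars
    seen_t, cands = set(), []
    with open(path, encoding="utf-8") as f:
        for line in f:
            p = line.rstrip("\n").split("\t")  # Tatoeba columns: id, lang, text
            if len(p) < 3:
                continue
            # Convert to zh-TW, then normalize half-width , ! ? to their full-width forms.
            t = cc.convert(p[2].strip()).replace(",", "，").replace("!", "！").replace("?", "？")
            h = HAN(t)
            if not (6 <= len(h) <= 26):  # 6..26 distinct Han characters
                continue
            # Drop sentences containing anything besides Han characters and allowed punctuation.
            if any(not ("一" <= c <= "鿿") and c not in allowed for c in t):
                continue
            if t not in seen_t:
                seen_t.add(t)
                cands.append(t)
    return greedy_cover(cands, HAN, n, seed)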


def select_en(path, n, seed=42):
    def words(s):
        # Coverage unit: lowercase word types (apostrophes kept, e.g. "don't").
        return set(re.findall(r"[a-z']+", s.lower()))

    seen_t, cands = set(), []
    with open(path, encoding="utf-8") as f:
        for line in f:
            p = line.rstrip("\n").split("\t")  # Tatoeba columns: id, lang, text
            if len(p) < 3:
                continue
            t = p[2].strip()
            # Plain ASCII letters, digits and basic punctuation only.
            if not re.fullmatch(r"[A-Za-z0-9 ,.\-'?!]+", t):
                continue
            w = re.findall(r"[A-Za-z']+", t)
            # 4..14 words, none longer than 15 characters.
            if not (4 <= len(w) <= 14) or any(len(x) > 15 for x in w):
                continue
            if t not in seen_t:
                seen_t.add(t)
                cands.append(t)
    return greedy_cover(cands, words, n, seed)


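def greedy_cover(cands, unit, n, seed):
    """One-pass greedy coverage of `unit(text)` items, then a random top-up to n.

    Candidates are shuffled (random tie-breaking) and sorted by their number of distinct
    units; walking that order, a sentence is kept if it adds at least one unit not yet
    covered. This approximates greedy max-coverage without re-scoring the remaining
    candidates after each pick. Leftover slots are filled with random unselected
    sentences so common units also recur (frequency).
    """
    rng = random.Random(seed)  # local RNG so the global random state is untouched
    rng.shuffle(cands)
    # Stable sort: candidates with equal unit counts keep their shuffled order.
    cands.sort(key=lambda t: len(unit(t)), reverse=True)
    covered, selected, rest = set(), [], []
    for t in cands:
        u = unit(t)
        if len(selected) < n and u - covered:
            selected.append(t)
            covered |= u
        else:
            rest.append(t)
    rng.shuffle(rest)
    selected += rest[: max(0, n - len(selected))]
    print(f"selected {len(selected)} sentences | unique units covered: {len(covered)}")
    return selected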
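

# Sketch (hypothetical helper, not wired into main()): exact greedy max-coverage using
# lazy evaluation, as an alternative to the one-pass heuristic in greedy_cover. Each pick
# is the sentence with the most still-uncovered units, re-scored against the current
# `covered` set. Because a sentence's gain can only shrink as `covered` grows, a stale
# heap entry is an upper bound, so only the popped candidate needs re-scoring. Takes the
# same arguments as greedy_cover but returns (selected, covered). Whether it buys
# noticeably more coverage on the Tatoeba dumps is untested here.
def lazy_greedy_cover(cands, unit, n, seed=42):
    import heapq
    import random

    rng = random.Random(seed)
    cands = list(cands)  # do not mutate the caller's list
    rng.shuffle(cands)  # ties are broken by shuffled index
    units = [unit(t) for t in cands]
    heap = [(-len(u), i) for i, u in enumerate(units)]  # max-heap via negated gains
    heapq.heapify(heap)
    covered, chosen, selected = set(), set(), []
    while heap and len(selected) < n:
        _, i = heapq.heappop(heap)
        gain = len(units[i] - covered)  # true marginal gain right now
        if gain == 0:
            continue  # adds nothing new; left for the random top-up
        if heap and gain < -heap[0][0]:
            heapq.heappush(heap, (-gain, i))  # stale bound: re-queue with the true gain
            continue
        selected.append(cands[i])
        chosen.add(i)
        covered |= units[i]
    # Random top-up with unpicked sentences, mirroring greedy_cover.
    rest = [t for i, t in enumerate(cands) if i not in chosen]
    rng.shuffle(rest)
    selected += rest[: max(0, n - len(selected))]
    return selected, covered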


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--lang", choices=["zh", "en"], required=True)
    ap.add_argument("--n", type=int, default=6000)
    ap.add_argument("--out", required=True)
    args = ap.parse_args()
    path = download("cmn" if args.lang == "zh" else "eng")
    sents = select_zh(path, args.n) if args.lang == "zh" else select_en(path, args.n)
    with open(args.out, "w", encoding="utf-8") as o:
        for i, t in enumerate(sents):
            o.write(f"{args.lang}e{i:05d}\t{t}\n")
    print("wrote", args.out, len(sents))


if __name__ == "__main__":
    main()