File size: 3,456 Bytes
f4c80a2
 
 
 
 
 
 
a50183d
 
f4c80a2
987c46e
f4c80a2
 
 
 
 
 
 
cbe5f84
f4c80a2
 
 
 
 
4d8ad2d
f4c80a2
 
e285e98
 
f4c80a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4d8ad2d
 
f4c80a2
 
4d8ad2d
025d5b1
4d8ad2d
f4c80a2
 
 
 
 
 
4d8ad2d
 
 
 
 
025d5b1
f4c80a2
 
a50183d
 
 
 
 
 
 
 
f4c80a2
a50183d
 
 
 
 
 
 
 
 
 
 
f4c80a2
a50183d
 
 
 
 
f4c80a2
a50183d
 
f4c80a2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import os
import json
import warnings
from typing import List
import re

from resource.pinyin_dict import PINYIN_DICT
from pypinyin import pinyin, Style
from zhconv import convert


def preprocess_input(src_str, seg_syb=" "):
    """Replace every newline and every space in ``src_str`` with ``seg_syb``.

    Newlines are substituted first and spaces second, so any space that is
    part of ``seg_syb`` itself is rewritten again by the second pass.

    Args:
        src_str: Raw input text.
        seg_syb: Separator substituted for each newline and each space.

    Returns:
        The text with both kinds of whitespace replaced.
    """
    without_newlines = src_str.replace("\n", seg_syb)
    return without_newlines.replace(" ", seg_syb)


def postprocess_phn(phns, model_name, lang):
    """Tag phonemes with their language unless the model uses bare phonemes.

    Args:
        phns: Sequence of phoneme strings.
        model_name: Model identifier; the aceopencpop model gets the
            phonemes back unchanged.
        lang: Language code appended after an ``@`` for every other model.

    Returns:
        ``phns`` itself for the aceopencpop model, otherwise a new list of
        ``"<phoneme>@<lang>"`` strings.
    """
    if model_name != "espnet/aceopencpop_svs_visinger2_40singer_pretrain":
        return [phn + "@" + lang for phn in phns]
    return phns


def pyopenjtalk_g2p(text) -> List[str]:
    """Convert text to a list of phonemes using pyopenjtalk.

    The characters of ``text`` are joined with single spaces before being
    passed to ``pyopenjtalk.g2p``. Warnings raised during the call are
    recorded so that a "No phoneme" warning can be detected.

    Args:
        text: Input text to convert.

    Returns:
        The phoneme strings obtained by splitting pyopenjtalk's output on
        single spaces. NOTE: despite the annotation, ``False`` is returned
        when any recorded warning message contains "No phoneme".
    """
    import pyopenjtalk

    with warnings.catch_warnings(record=True) as caught:
        # Record every warning, including repeats, so none is filtered out.
        warnings.simplefilter("always")
        # Put a space between each character of the input.
        spaced = " ".join(text)
        # g2p yields a single space-separated string of phonemes.
        phones = pyopenjtalk.g2p(spaced, kana=False)
        no_phoneme = any("No phoneme" in str(item.message) for item in caught)
    if no_phoneme:
        return False
    return phones.split(" ")


def split_pinyin_ace(pinyin: str, zh_plan: dict) -> tuple[str]:
    """Look up the phoneme split of a pinyin syllable in a plan dictionary.

    The syllable is lower-cased, then searched for directly in
    ``zh_plan["dict"]``; if absent there, it is resolved through
    ``zh_plan["syllable_alias"]`` and the aliased syllable is looked up in
    ``zh_plan["dict"]`` instead.

    Args:
        pinyin: Pinyin syllable (case-insensitive).
        zh_plan: Mapping with a ``"dict"`` table and a ``"syllable_alias"``
            table.

    Returns:
        The entry stored in ``zh_plan["dict"]`` for the syllable or its
        alias. NOTE: despite the annotation, ``False`` is returned when the
        syllable is in neither table.
    """
    syllable = pinyin.lower()
    table = zh_plan["dict"]
    if syllable in table:
        return table[syllable]

    aliases = zh_plan["syllable_alias"]
    if syllable in aliases:
        return table[aliases[syllable]]

    return False


def split_pinyin_py(pinyin: str) -> tuple[str]:
    """Look up the phoneme split of a pinyin syllable in ``PINYIN_DICT``.

    Args:
        pinyin: Pinyin syllable (case-insensitive).

    Returns:
        The ``PINYIN_DICT`` entry for the lower-cased syllable. NOTE: despite
        the annotation, ``False`` is returned when the syllable is not in
        the dictionary.
    """
    syllable = pinyin.lower()
    return PINYIN_DICT[syllable] if syllable in PINYIN_DICT else False


def get_tokenizer(model, lang):
    """Return the callable that splits one text token into phonemes.

    Args:
        model: Model identifier. Supported values are
            ``"espnet/aceopencpop_svs_visinger2_40singer_pretrain"`` and
            ``"espnet/mixdata_svs_visinger2_spkemb_lang_pretrained"``.
        lang: Language code. ``"zh"`` is accepted by both models; ``"jp"``
            only by the mixdata model.

    Returns:
        A callable taking a single ``text`` argument. For ``"zh"`` it wraps
        ``split_pinyin_py`` (aceopencpop) or ``split_pinyin_ace`` bound to
        the ``"zh"`` plan read from ``resource/all_plans.json`` (mixdata);
        for ``"jp"`` it is ``pyopenjtalk_g2p``.

    Raises:
        ValueError: If the model or language is unsupported, or if the plan
            file contains no plan whose language is ``"zh"``.
    """
    if model == "espnet/aceopencpop_svs_visinger2_40singer_pretrain":
        if lang != "zh":
            raise ValueError(f"Only support Chinese language for {model}")
        return lambda text: split_pinyin_py(text)

    if model == "espnet/mixdata_svs_visinger2_spkemb_lang_pretrained":
        if lang == "zh":
            # The path is relative to the current working directory.
            plan_path = os.path.join("resource", "all_plans.json")
            # Read as UTF-8 explicitly instead of the locale default.
            with open(plan_path, "r", encoding="utf-8") as f:
                all_plan_dict = json.load(f)

            # Keep the last matching plan, as before, but detect the case
            # where none matches instead of failing later inside the lambda.
            zh_plan = None
            for plan in all_plan_dict["plans"]:
                if plan["language"] == "zh":
                    zh_plan = plan
            if zh_plan is None:
                raise ValueError(f"No 'zh' plan found in {plan_path}")
            return lambda text: split_pinyin_ace(text, zh_plan)
        if lang == "jp":
            return pyopenjtalk_g2p
        raise ValueError(f"Only support Chinese and Japanese language for {model}")

    raise ValueError("Only support espnet/aceopencpop_svs_visinger2_40singer_pretrain and espnet/mixdata_svs_visinger2_spkemb_lang_pretrained for now")


def is_chinese(char):
    """Return True when ``char`` lies in the U+4E00..U+9FFF range (inclusive).

    The check is a plain string comparison, so a multi-character string is
    compared lexicographically against the two bounds; an empty string
    yields False.
    """
    below_range = char < '\u4e00'
    above_range = char > '\u9fff'
    return not (below_range or above_range)


def is_special(char):
    """Return True when ``char`` consists only of marker characters.

    The pattern is a character class, so any non-empty mix of ``-``, ``—``,
    ``A``, ``P`` and ``S`` matches (for example ``"AP"``, ``"SP"``, ``"——"``,
    but also ``"A"`` or ``"PS"``).

    NOTE(review): presumably the intent is to recognise the whole tokens
    ``-``, ``——``, ``AP`` and ``SP`` — confirm before tightening the regex.
    """
    matched = re.match(r'^[-——APSP]+$', char)
    return bool(matched)


def get_pinyin(texts):
    """Convert the Chinese characters in ``texts`` to pinyin, block by block.

    Spaces and newlines are removed first. The remaining text is split into
    blocks, each either a single character in U+4E00..U+9FFF or a maximal
    run of characters outside that range. The Chinese characters are joined,
    passed through ``convert(..., 'zh-cn')``, and transcribed with
    ``pinyin(..., style=Style.NORMAL)``; the first reading of each item is
    kept.

    Args:
        texts: Input text, possibly mixing Chinese and other characters.

    Returns:
        A list with one entry per block, in order: the pinyin reading for a
        Chinese character block, the block itself otherwise.
    """
    stripped = preprocess_input(texts, seg_syb="")
    # One Chinese character per block, or one run of everything else.
    blocks = re.findall(r'[\u4e00-\u9fff]|[^\u4e00-\u9fff]+', stripped)

    hanzi = ''.join(block for block in blocks if is_chinese(block))
    hanzi = convert(hanzi, 'zh-cn')
    readings = [candidates[0] for candidates in pinyin(hanzi, style=Style.NORMAL)]

    # Walk the blocks again, consuming one reading per Chinese block.
    text_list = []
    cursor = 0
    for block in blocks:
        if not is_chinese(block):
            text_list.append(block)
            continue
        text_list.append(readings[cursor])
        cursor += 1

    return text_list