|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from text.cleaner import clean_text |
|
|
import LangSegment |
|
|
from text import symbols as symbols_v1 |
|
|
from TN_processors import punct_normalization, MultilingualTN, alphabet_normalization |
|
|
import sys |
|
|
|
|
|
class BaseG2P:
    """Base grapheme-to-phoneme helper.

    Holds the special symbols shared by all G2P front ends: silence,
    punctuation replacement tokens, and the erhua phoneme split table.
    """

    def __init__(self):
        # Special tokens emitted in place of silence / raw punctuation.
        self.sil_symbol = '[SIL]'
        self.comma_symbol = '[CM]'
        self.peroid_symbol = '[PD]'  # (sic) attribute name kept for compatibility
        self.question_symbol = '[QN]'
        self.exclamation_symbol = '[EX]'

        # Raw punctuation character -> replacement token.
        self.punct_to_symbol = {
            ',': self.comma_symbol,
            '.': self.peroid_symbol,
            '!': self.exclamation_symbol,
            '?': self.question_symbol,
        }
        # Erhua phoneme -> (vowel, retroflex) pair used by subclasses.
        self.er_mapping = {
            'er1': ('e1', 'rr'),
            'er2': ('e2', 'rr'),
            'er3': ('e3', 'rr'),
            'er4': ('e4', 'rr'),
            'er5': ('e5', 'rr'),
            'r5': ('e5', 'rr'),
        }

    def replace_punct_with_symbol(self, phone_list):
        """Return a copy of *phone_list* with ,.!? swapped for their tokens.

        Phones that are not punctuation are passed through unchanged.
        """
        return [self.punct_to_symbol.get(ph, ph) for ph in phone_list]
|
|
|
|
|
|
|
|
class MultilingualG2P(BaseG2P):
    """Grapheme-to-phoneme front end for mixed Chinese/English text.

    Combines a text-normalization backend (the "wenet" MultilingualTN
    pipeline, or the "baidu" clean_text pipeline) with per-language g2p
    modules imported dynamically from the ``text`` package.
    """

    def __init__(self, module="wenet", remove_interjections=False, remove_erhua=True):
        # module: text-normalization backend, "wenet" or "baidu".
        # remove_interjections / remove_erhua: forwarded to MultilingualTN;
        # they only take effect with the "wenet" backend.
        BaseG2P.__init__(self)
        self.tn_module = module
        # Language code -> g2p submodule name under the "text" package.
        self.language_module_map = {"zh": "chinese", "en": "english"}
        self.version = "v1"
        # When True, English g2p emits word-boundary markers
        # (g2p_with_boundary). The Chinese flag is stored but never read
        # in this class — NOTE(review): confirm whether it is consumed elsewhere.
        self.output_eng_word_boundary = False
        self.output_chn_word_boundary = False
        if self.tn_module == "wenet":
            self.tn_wenet = MultilingualTN("wenet", remove_interjections, remove_erhua)
        return

    def set_output_eng_word_boundary(self, output_eng_word_boundary):
        # Toggle word-boundary markers in English phoneme output.
        self.output_eng_word_boundary = output_eng_word_boundary

    def set_output_chn_word_boundary(self, output_chn_word_boundary):
        # Stored only; no code in this class reads it back.
        self.output_chn_word_boundary = output_chn_word_boundary

    def g2p_for_norm_text(self, norm_text, language):
        """Convert already-normalized text to phonemes.

        Returns (phones, word2ph): word2ph is a per-character phone count
        for Chinese and None otherwise. Any phone missing from the symbol
        table is replaced with 'UNK'.
        """
        symbols = symbols_v1.symbols
        if(language not in self.language_module_map):
            # Unknown language codes fall back to Chinese.
            language="zh"
            # NOTE(review): this assigns an unused local `text`; it looks
            # like it was meant to reset `norm_text` to " " — confirm intent.
            text=" "
        # Dynamically import text.chinese / text.english for this language.
        language_module = __import__("text."+self.language_module_map[language],fromlist=[self.language_module_map[language]])
        if language == "zh":
            phones, word2ph = language_module.g2p(norm_text)
            # Sanity checks: phone counts must align 1:1 with input characters.
            assert len(phones) == sum(word2ph)
            assert len(norm_text) == len(word2ph)
        elif language == "en":
            if self.output_eng_word_boundary:
                phones = language_module.g2p_with_boundary(norm_text)
            else:
                phones = language_module.g2p(norm_text)
            word2ph = None
        else:
            phones = language_module.g2p(norm_text)
            word2ph = None
        # Map out-of-vocabulary phones to 'UNK'.
        phones = ['UNK' if ph not in symbols else ph for ph in phones]
        return phones, word2ph

    def text_normalization_and_g2p(self, text, language, with_lang_prefix=False, normalize_punct=False):
        '''
        Normalize raw text and convert it to phonemes.

        language in {en, zh, auto}: "zh"/"auto" accept a mixture of English
        and Chinese input (segments are routed per detected language by
        LangSegment); "en" accepts English-only input.

        with_lang_prefix: prefix each alphabetic phone with "<lang>_".
        NOTE(review): the prefix is only applied in the zh/auto branch,
        never in the "en" branch — confirm whether that is intentional.
        normalize_punct: run punct_normalization first and replace the
        punctuation phones ,.!? with the [CM]/[PD]/[EX]/[QN] tokens.

        Returns (phones, all_norm_text).
        '''
        if normalize_punct:
            text = punct_normalization(text)

        # Normalize letter variants, then lowercase everything.
        text = alphabet_normalization(text)
        text = text.lower()

        if language in {"en"}:
            # Leftover from a variant that accepted "all_en"; a no-op here.
            language = language.replace("all_", "")
            if language == "en":
                # Keep only the segments LangSegment tags as English.
                LangSegment.setfilters(["en"])
                formattext = " ".join(tmp["text"]
                                      for tmp in LangSegment.getTexts(text))
            else:
                formattext = text
            # Collapse runs of spaces to a single space.
            while "  " in formattext:
                formattext = formattext.replace("  ", " ")
            if self.tn_module == "baidu":
                phones, word2ph, norm_text = clean_text(
                    formattext, language, self.version)
                all_norm_text = norm_text
            else:
                norm_formattext = self.tn_wenet.normalize_segment(formattext, language, normalize_punct)
                phones, word2ph = self.g2p_for_norm_text(norm_formattext, language)
                all_norm_text = norm_formattext
        elif language in {"zh", "auto"}:
            # Split the input into language-tagged segments.
            textlist = []
            langlist = []
            LangSegment.setfilters(["en", "zh", "ja", "ko"])

            if language == "auto":
                for tmp in LangSegment.getTexts(text):
                    langlist.append(tmp["lang"])
                    textlist.append(tmp["text"])
            else:
                for tmp in LangSegment.getTexts(text):
                    if tmp["lang"] == "en":
                        langlist.append(tmp["lang"])
                    else:
                        # Force every non-English segment to the requested
                        # language ("zh") regardless of the detected tag.
                        langlist.append(language)
                    textlist.append(tmp["text"])

            # Merge adjacent segments that share a language:
            # mergelist[i] == 1 means segment i joins the previous one.
            mergelist = []
            for idx in range(len(textlist)):
                if idx > 0 and langlist[idx - 1] == langlist[idx]:
                    mergelist.append(1)
                else:
                    mergelist.append(0)
            merged_textlist = []
            merged_langlist = []
            for idx in range(len(mergelist)):
                if mergelist[idx] == 0:
                    merged_textlist.append(textlist[idx])
                    merged_langlist.append(langlist[idx])
                else:
                    merged_textlist[-1] += " " + textlist[idx]

            textlist = merged_textlist
            langlist = merged_langlist

            assert len(textlist) == len(langlist)

            # Normalize + g2p each merged segment in its own language.
            phones_list = []
            norm_text_list = []
            for i in range(len(textlist)):
                lang = langlist[i]
                if self.tn_module == "wenet":
                    norm_text = self.tn_wenet.normalize_segment(textlist[i], lang, normalize_punct)
                    phones, word2ph = self.g2p_for_norm_text(norm_text, lang)
                else:
                    phones, word2ph, norm_text = clean_text(
                        textlist[i], lang, self.version)

                if with_lang_prefix:
                    # Prefix alphabetic phones with the language code;
                    # punctuation/symbol phones pass through untouched.
                    # NOTE(review): assumes every phone is non-empty,
                    # otherwise ph[0] raises IndexError — confirm upstream.
                    phones_with_lang = []
                    for ph in phones:
                        if ph[0].isalpha():
                            phones_with_lang.append(lang + '_' + ph)
                        else:
                            phones_with_lang.append(ph)
                    phones_list.append(phones_with_lang)
                else:
                    phones_list.append(phones)
                norm_text_list.append(norm_text)
            # Flatten per-segment phones; join normalized segments with spaces.
            phones = sum(phones_list, [])
            all_norm_text = ' '.join(norm_text_list)

        # NOTE(review): if language is none of {"en", "zh", "auto"}, `phones`
        # and `all_norm_text` are never assigned and the return below raises
        # UnboundLocalError — confirm callers never pass other codes.
        if normalize_punct:
            phones = self.replace_punct_with_symbol(phones)

        return phones, all_norm_text
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Interactive smoke test: read lines from stdin, print the normalized
    # text and phoneme sequence for each; "exit()" quits.

    def _prompt():
        # Emit the input prompt without a trailing newline.
        sys.stdout.write("Input: ")
        sys.stdout.flush()

    language = 'zh'
    g2p = MultilingualG2P("wenet", remove_interjections=False, remove_erhua=False)
    g2p.set_output_eng_word_boundary(True)

    _prompt()
    for line in sys.stdin:
        stripped = line.strip()
        if stripped == "exit()":
            exit()
        if not stripped:
            _prompt()
            continue
        phones, norm_text = g2p.text_normalization_and_g2p(
            stripped, language, with_lang_prefix=True, normalize_punct=True)
        sys.stdout.write("Norm Text: " + norm_text + "\n")
        sys.stdout.write("phonemes: " + " ".join(phones) + "\n")
        _prompt()
|
|
|