# DSTK/thirdparty/G2P/G2P_processors.py
# Provenance (repository page artifact converted to a comment so the file parses):
#   author: gooorillax
#   commit cd8454d — "first push of codes and models for g2p, t2u, tokenizer and detokenizer"
# -*- coding: utf-8 -*-
# This project combines the TN and G2P functions of https://github.com/RVC-Boss/GPT-SoVITS and https://github.com/wenet-e2e/WeTextProcessing
# Huawei Technologies Co., Ltd. (authors: Xiao Chen)
from text.cleaner import clean_text
import LangSegment
from text import symbols as symbols_v1
from TN_processors import punct_normalization, MultilingualTN, alphabet_normalization
import sys
class BaseG2P:
    """Base class for grapheme-to-phoneme (G2P) converters.

    Holds the shared inventory of special symbols (silence and punctuation
    tokens) plus the erhua ("儿化") pinyin-final mapping used by subclasses.
    """

    def __init__(self):
        # Token emitted for silence.
        self.sil_symbol = '[SIL]'
        # Punctuation tokens.  NOTE: "peroid" is a historical typo; the
        # attribute name is kept so existing callers are not broken.
        self.comma_symbol = '[CM]'        # ,
        self.peroid_symbol = '[PD]'       # .
        self.question_symbol = '[QN]'     # ?
        self.exclamation_symbol = '[EX]'  # !
        # Raw punctuation mark -> bracketed symbol token.
        self.punct_to_symbol = {',': self.comma_symbol, '.': self.peroid_symbol,
                                '!': self.exclamation_symbol, '?': self.question_symbol}
        # Erhua pinyin final -> (base final, retroflex marker 'rr').
        self.er_mapping = {'er1': ('e1', 'rr'), 'er2': ('e2', 'rr'), 'er3': ('e3', 'rr'),
                           'er4': ('e4', 'rr'), 'er5': ('e5', 'rr'), 'r5': ('e5', 'rr')}

    def replace_punct_with_symbol(self, phone_list):
        """Return a copy of *phone_list* with raw punctuation marks
        (',', '.', '!', '?') replaced by their bracketed symbol tokens;
        every other phone passes through unchanged.
        """
        return [self.punct_to_symbol.get(ph, ph) for ph in phone_list]
class MultilingualG2P(BaseG2P):
    """Multilingual (Chinese / English) text-normalization + G2P front end.

    Combines a text-normalization backend ("wenet" via MultilingualTN, or
    "baidu" via ``clean_text``) with per-language G2P modules from the
    ``text`` package, using ``LangSegment`` to split mixed-language input.
    """

    def __init__(self, module="wenet", remove_interjections=False, remove_erhua=True):
        """Create the converter.

        Args:
            module: text-normalization backend, "wenet" or "baidu".
            remove_interjections: forwarded to MultilingualTN ("wenet" only).
            remove_erhua: forwarded to MultilingualTN ("wenet" only).
        """
        BaseG2P.__init__(self)
        self.tn_module = module
        # Supported language code -> submodule name under the ``text`` package.
        self.language_module_map = {"zh": "chinese", "en": "english"}
        self.version = "v1"
        self.output_eng_word_boundary = False
        # Stored but not consumed in this file — presumably used by a
        # Chinese G2P variant elsewhere; TODO confirm.
        self.output_chn_word_boundary = False
        if self.tn_module == "wenet":
            self.tn_wenet = MultilingualTN("wenet", remove_interjections, remove_erhua)

    def set_output_eng_word_boundary(self, output_eng_word_boundary):
        # Toggle word-boundary markers in the English phoneme output.
        self.output_eng_word_boundary = output_eng_word_boundary

    def set_output_chn_word_boundary(self, output_chn_word_boundary):
        # Toggle word-boundary markers in the Chinese phoneme output.
        self.output_chn_word_boundary = output_chn_word_boundary

    def g2p_for_norm_text(self, norm_text, language):
        """Convert already-normalized text to phonemes.

        Args:
            norm_text: normalized text segment.
            language: "zh" or "en"; anything else falls back to "zh".

        Returns:
            (phones, word2ph) where ``word2ph`` maps each input character to
            its phone count for Chinese and is ``None`` for other languages.
            Phones outside the model's symbol table become 'UNK'.
        """
        symbols = symbols_v1.symbols
        if language not in self.language_module_map:
            # Unknown language: fall back to Chinese.
            # (Fix: dropped the original dead assignment ``text = " "`` —
            # the name was never read afterwards.)
            language = "zh"
        module_name = self.language_module_map[language]
        language_module = __import__("text." + module_name, fromlist=[module_name])
        if language == "zh":
            phones, word2ph = language_module.g2p(norm_text)
            assert len(phones) == sum(word2ph)
            assert len(norm_text) == len(word2ph)
        elif language == "en":
            if self.output_eng_word_boundary:
                phones = language_module.g2p_with_boundary(norm_text)
            else:
                phones = language_module.g2p(norm_text)
            word2ph = None
        else:
            # Defensive; unreachable given the "zh" fallback above.
            phones = language_module.g2p(norm_text)
            word2ph = None
        phones = ['UNK' if ph not in symbols else ph for ph in phones]
        return phones, word2ph

    def text_normalization_and_g2p(self, text, language, with_lang_prefix=False, normalize_punct=False):
        """Normalize ``text`` and convert it to a phoneme sequence.

        Args:
            text: raw input text.
            language: "en" (everything treated as English), "zh" (mixed
                English/Chinese supported; non-English treated as Chinese),
                or "auto" (per-segment language detection).
            with_lang_prefix: prefix alphabetic phones with "<lang>_"
                (intended for mixed-language input).
            normalize_punct: normalize punctuation up front and replace
                punctuation phones with bracketed symbol tokens at the end.

        Returns:
            (phones, all_norm_text).

        Raises:
            ValueError: for an unsupported ``language`` (previously this
            path crashed with a NameError at the return statement).
        """
        if normalize_punct:
            text = punct_normalization(text)
        text = alphabet_normalization(text)
        text = text.lower()
        if language in {"en"}:
            # (Fix: dropped the no-op ``language.replace("all_", "")`` —
            # language is exactly "en" inside this branch.)
            LangSegment.setfilters(["en"])
            formattext = " ".join(seg["text"] for seg in LangSegment.getTexts(text))
            # Collapse runs of spaces introduced by segmentation.
            # (Fix: reconstructed the double-space collapse from the upstream
            # GPT-SoVITS code; the extracted single-space form looped forever.)
            while "  " in formattext:
                formattext = formattext.replace("  ", " ")
            if self.tn_module == "baidu":
                phones, word2ph, norm_text = clean_text(
                    formattext, language, self.version)
                all_norm_text = norm_text
            else:
                norm_formattext = self.tn_wenet.normalize_segment(formattext, language, normalize_punct)
                phones, word2ph = self.g2p_for_norm_text(norm_formattext, language)
                all_norm_text = norm_formattext
        elif language in {"zh", "auto"}:
            textlist = []
            langlist = []
            LangSegment.setfilters(["en", "zh", "ja", "ko"])
            if language == "auto":
                for seg in LangSegment.getTexts(text):
                    langlist.append(seg["lang"])
                    textlist.append(seg["text"])
            else:
                for seg in LangSegment.getTexts(text):
                    if seg["lang"] == "en":
                        langlist.append(seg["lang"])
                    else:
                        # CJK characters are ambiguous between zh/ja/ko, so
                        # trust the caller-specified language.
                        langlist.append(language)
                    textlist.append(seg["text"])
            # Merge consecutive segments that share a language (single pass;
            # equivalent to the original mergelist bookkeeping).
            merged_textlist = []
            merged_langlist = []
            for seg_text, seg_lang in zip(textlist, langlist):
                if merged_langlist and merged_langlist[-1] == seg_lang:
                    merged_textlist[-1] += " " + seg_text
                else:
                    merged_textlist.append(seg_text)
                    merged_langlist.append(seg_lang)
            textlist = merged_textlist
            langlist = merged_langlist
            assert len(textlist) == len(langlist)
            phones_list = []
            norm_text_list = []
            for seg_text, seg_lang in zip(textlist, langlist):
                if self.tn_module == "wenet":
                    norm_text = self.tn_wenet.normalize_segment(seg_text, seg_lang, normalize_punct)
                    phones, word2ph = self.g2p_for_norm_text(norm_text, seg_lang)
                else:
                    phones, word2ph, norm_text = clean_text(
                        seg_text, seg_lang, self.version)
                if with_lang_prefix:
                    # Prefix only alphabetic phones (punctuation stays bare).
                    phones = [seg_lang + '_' + ph if ph[0].isalpha() else ph
                              for ph in phones]
                phones_list.append(phones)
                norm_text_list.append(norm_text)
            phones = sum(phones_list, [])
            all_norm_text = ' '.join(norm_text_list)
        else:
            raise ValueError("unsupported language: " + repr(language))
        if normalize_punct:
            phones = self.replace_punct_with_symbol(phones)
        return phones, all_norm_text
if __name__ == '__main__':
    '''
    Interactive smoke test: read a line from stdin, print the normalized
    text and phoneme sequence; stop on "exit()" or EOF.
    '''
    # Sample inputs for manual testing:
    # text = '1983年2月,旅行了2天的儿童和长翅膀的女孩儿:“︘菜单修订后有鱼香肉丝儿、『王道椒香鸡腿〕和川蜀鸡翅?……”it\'s a test 112.王会计会计算机。which had been in force since 1760.调查员决定调节调查的难度。Article VI, Qing government would be charged an annual interest rate of 5% for the money.√2和π是不是无理数?'
    # text = '马打兰王国(732-1006),是8世纪到10世纪期间,存在于中爪哇的一个印度化王国。'
    language = 'zh'  # zh: treat all non-English as Chinese; en: treat everything as English.
    mG2P = MultilingualG2P("wenet", remove_interjections=False, remove_erhua=False)  # 'baidu' or 'wenet'
    mG2P.set_output_eng_word_boundary(True)

    def _prompt():
        # Write the prompt without a newline and force it to appear.
        sys.stdout.write("Input: ")
        sys.stdout.flush()

    _prompt()
    for raw_line in sys.stdin:
        stripped = raw_line.strip()
        if stripped == "exit()":
            # Fix: ``break`` instead of the site-provided ``exit()`` builtin,
            # which is not guaranteed to exist outside interactive sessions.
            break
        if not stripped:
            _prompt()
            continue
        phones, norm_text = mG2P.text_normalization_and_g2p(
            stripped, language, with_lang_prefix=True, normalize_punct=True)
        sys.stdout.write("Norm Text: " + norm_text + "\n")
        sys.stdout.write("phonemes: " + " ".join(phones) + "\n")
        _prompt()