|
|
|
|
|
import re |
|
|
import regex |
|
|
import inflect |
|
|
from functools import partial |
|
|
from wetext import Normalizer |
|
|
|
|
|
# Matches one or more CJK unified ideographs (the common Chinese character range).
chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+')


def contains_chinese(text):
    """Return True when *text* contains at least one Chinese character."""
    return chinese_char_pattern.search(text) is not None
|
|
|
|
|
|
|
|
|
|
|
def replace_corner_mark(text):
    """Spell out math symbols as their Chinese readings.

    Covers superscripts (squared/cubed), square root, approximate
    equality and less-than.
    """
    readings = (
        ('²', '平方'),
        ('³', '立方'),
        ('√', '根号'),
        ('≈', '约等于'),
        ('<', '小于'),
    )
    for symbol, spoken in readings:
        text = text.replace(symbol, spoken)
    return text
|
|
|
|
|
|
|
|
|
|
|
def remove_bracket(text):
    """Strip bracket-like punctuation that should not be spoken.

    Full-width parentheses and lenticular brackets become spaces,
    backticks are deleted, and the em-dash pair "——" becomes a space.
    """
    # NOTE: the two backtick replacements mirror the original exactly.
    replacements = (
        ('(', ' '),
        (')', ' '),
        ('【', ' '),
        ('】', ' '),
        ('`', ''),
        ('`', ''),
        ("——", " "),
    )
    for target, substitute in replacements:
        text = text.replace(target, substitute)
    return text
|
|
|
|
|
|
|
|
|
|
|
def spell_out_number(text: str, inflect_parser):
    """Replace each maximal run of digit characters with its English words.

    Args:
        text: input string, possibly mixing digits and other characters.
        inflect_parser: an object exposing ``number_to_words(str) -> str``
            (typically ``inflect.engine()``).

    Returns:
        The text with every digit run spelled out; non-digit characters
        are passed through unchanged.
    """
    pieces = []
    run_start = None  # index where the current digit run began, or None
    for idx, ch in enumerate(text):
        if ch.isdigit():
            if run_start is None:
                run_start = idx
            continue
        # Non-digit: flush any pending digit run, then keep the character.
        if run_start is not None:
            pieces.append(inflect_parser.number_to_words(text[run_start:idx]))
            run_start = None
        pieces.append(ch)
    # Flush a digit run that extends to the end of the string.
    if run_start is not None:
        pieces.append(inflect_parser.number_to_words(text[run_start:]))
    return ''.join(pieces)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False):
    """Split *text* into utterances at sentence-final punctuation, then
    re-pack them into chunks bounded by ``token_max_n``.

    Args:
        text: input paragraph.
        tokenize: callable used to measure length for non-Chinese text
            (unused when ``lang == "zh"``, where character count is used).
        lang: "zh" for Chinese, anything else is treated as English-like.
        token_max_n: soft upper bound of a packed chunk's length.
        token_min_n: minimum length a chunk must reach before it is closed.
        merge_len: a trailing chunk shorter than this is merged into the
            previous one.
        comma_split: also split at ASCII/full-width commas.

    Returns:
        List of utterance strings.

    Note: text after the final punctuation mark is dropped — original
    behavior, preserved here.
    """
    def calc_utt_length(_text: str):
        # Chinese length is counted in characters, other languages in tokens.
        if lang == "zh":
            return len(_text)
        else:
            return len(tokenize(_text))

    def should_merge(_text: str):
        if lang == "zh":
            return len(_text) < merge_len
        else:
            return len(tokenize(_text)) < merge_len

    if lang == "zh":
        pounc = ['。', '?', '!', ';', ':', '、', '.', '?', '!', ';']
    else:
        pounc = ['.', '?', '!', ';', ':']
    if comma_split:
        pounc.extend([',', ','])

    st = 0
    utts = []
    for i, c in enumerate(text):
        if c in pounc:
            if len(text[st: i]) > 0:
                utts.append(text[st: i] + c)
            # Attach a closing quote that directly follows the punctuation
            # to the previous utterance.  Guard against an empty `utts`
            # (e.g. text starting with punctuation followed by a quote),
            # which previously raised IndexError on pop().
            if i + 1 < len(text) and text[i + 1] in ['"', '”'] and len(utts) > 0:
                tmp = utts.pop(-1)
                utts.append(tmp + text[i + 1])
                st = i + 2
            else:
                st = i + 1
    if len(utts) == 0:
        # No sentence-final punctuation at all: treat the whole text as one
        # utterance and close it with a language-appropriate full stop.
        if lang == "zh":
            utts.append(text + '。')
        else:
            utts.append(text + '.')

    # Greedy re-packing: close the current chunk only once it exceeds
    # token_max_n AND the chunk alone already exceeds token_min_n.
    final_utts = []
    cur_utt = ""
    for utt in utts:
        if calc_utt_length(cur_utt + utt) > token_max_n and calc_utt_length(cur_utt) > token_min_n:
            final_utts.append(cur_utt)
            cur_utt = ""
        cur_utt = cur_utt + utt
    if len(cur_utt) > 0:
        # Merge a short trailing chunk into the previous one when possible.
        if should_merge(cur_utt) and len(final_utts) != 0:
            final_utts[-1] = final_utts[-1] + cur_utt
        else:
            final_utts.append(cur_utt)

    return final_utts
|
|
|
|
|
|
|
|
|
|
|
def replace_blank(text: str):
    """Drop spaces unless both neighbors are non-space ASCII characters.

    Keeps word-separating spaces inside English text ("a b") while
    removing spaces around CJK characters ("你 好" -> "你好").

    Fixes two boundary defects in the original:
    - a trailing space raised IndexError via ``text[i + 1]``;
    - a leading space read ``text[-1]`` (the LAST character) via Python's
      negative indexing, keeping/dropping it based on the wrong neighbor.
    Out-of-range neighbors are now treated as non-ASCII, so edge spaces
    are removed.
    """
    out_str = []
    n = len(text)
    for i, c in enumerate(text):
        if c != " ":
            out_str.append(c)
            continue
        prev_ok = i > 0 and text[i - 1].isascii() and text[i - 1] != " "
        next_ok = i + 1 < n and text[i + 1].isascii() and text[i + 1] != " "
        if prev_ok and next_ok:
            out_str.append(c)
    return "".join(out_str)
|
|
|
|
|
def clean_markdown(md_text: str) -> str:
    """Strip common Markdown constructs, returning plain text.

    The rule order matters: fenced code blocks must go before inline
    code, and image syntax before plain links.
    """
    rules = (
        (r"```.*?```", "", re.DOTALL),          # fenced code blocks
        (r"`[^`]*`", "", 0),                    # inline code spans
        (r"!\[[^\]]*\]\([^\)]+\)", "", 0),      # images (dropped entirely)
        (r"\[([^\]]+)\]\([^)]+\)", r"\1", 0),   # links -> anchor text
        (r'^(\s*)-\s+', r'\1', re.MULTILINE),   # list bullets
        (r"<[^>]+>", "", 0),                    # inline HTML tags
        (r"^#{1,6}\s*", "", re.MULTILINE),      # heading markers
        (r"\n\s*\n", "\n", 0),                  # collapse blank lines
    )
    for pattern, replacement, flags in rules:
        md_text = re.sub(pattern, replacement, md_text, flags=flags)
    return md_text.strip()
|
|
|
|
|
|
|
|
def clean_text(text):
    """Pre-clean raw input text before TN: strip Markdown, remove emoji,
    flatten whitespace, and normalize straight double quotes.

    Returns the cleaned single-line string.
    """
    # Remove Markdown structure first so emoji/whitespace passes see plain text.
    text = clean_markdown(text)
    # Remove emoji: presentation-style emoji plus base emoji followed by
    # the VS16 variation selector.
    text = regex.compile(r'\p{Emoji_Presentation}|\p{Emoji}\uFE0F', flags=regex.UNICODE).sub("", text)
    text = text.replace("\n", " ")
    text = text.replace("\t", " ")
    # BUGFIX: the original replacement string was "\“" — an invalid escape
    # sequence that Python keeps verbatim, so '"' was replaced by a literal
    # backslash followed by '“'.  Replace with the curly quote alone.
    text = text.replace('"', "“")
    return text
|
|
|
|
|
class TextNormalizer: |
|
|
    def __init__(self, tokenizer=None):
        """Build the normalizer's backing models.

        Args:
            tokenizer: optional tokenizer; stored on the instance but not
                used within this method.
        """
        self.tokenizer = tokenizer
        # wetext normalizers: Chinese TN with erhua ("儿化") removal, and English TN.
        self.zh_tn_model = Normalizer(lang="zh", operator="tn", remove_erhua=True)
        self.en_tn_model = Normalizer(lang="en", operator="tn")
        # inflect engine used to spell out digit runs in English text.
        self.inflect_parser = inflect.engine()
|
|
|
|
|
    def normalize(self, text, split=False):
        """Normalize *text* for TTS.

        Language is chosen per call: any CJK character routes the whole
        text through the Chinese pipeline, otherwise English is used.

        Args:
            text: raw input text (may contain Markdown, emoji, etc.).
            split: when False, the normalized string is returned directly.
        """
        # Route on content, not on a caller-supplied flag.
        lang = "zh" if contains_chinese(text) else "en"
        text = clean_text(text)
        if lang == "zh":
            text = text.replace("=", "等于")
            # Only touch hyphens when the text actually contains digits or
            # math symbols, to avoid mangling ordinary hyphenated words.
            if re.search(r'([\d$%^*_+≥≤≠×÷?=])', text):
                # Pad a minus sign that sits between an alphanumeric and a
                # digit so TN reads it as an operator.
                text = re.sub(r'(?<=[a-zA-Z0-9])-(?=\d)', ' - ', text)
            text = self.zh_tn_model.normalize(text)
            text = replace_blank(text)
            text = replace_corner_mark(text)
            text = remove_bracket(text)
        else:
            text = self.en_tn_model.normalize(text)
            text = spell_out_number(text, self.inflect_parser)
        # NOTE(review): the split=True path is not visible in this view —
        # confirm downstream handling before relying on it.
        if split is False:
            return text