| import re |
| import jieba |
| import string |
| from pypinyin import pinyin, Style, load_phrases_dict |
| from pypinyin_dict.phrase_pinyin_data import cc_cedict |
| from g2p_en import G2p |
| import os |
| import json |
|
|
| |
# Single g2p_en converter, built once at import time and shared by every
# _trans_en call.
g2p = G2p()


# Load the CC-CEDICT phrase pinyin data from pypinyin-dict into pypinyin at
# import time.
# NOTE(review): presumably this improves phrase-level readings of polyphonic
# characters — confirm against the pypinyin-dict docs.
cc_cedict.load()
|
|
| |
# Optional polyphone corrections: the parsed JSON is passed straight to
# pypinyin's load_phrases_dict. This runs after cc_cedict.load(), presumably so
# these entries take precedence over the CC-CEDICT data (confirm).
# NOTE(review): the path is relative, so it is resolved against the current
# working directory rather than this module's location. A missing file is
# skipped without any message.
poly_correct_path = "eval_pipeline/poly_correct.json"
if os.path.exists(poly_correct_path):
    try:
        with open(poly_correct_path, 'r', encoding='utf-8') as f:
            poly_correct = json.load(f)

        # Register the corrections with pypinyin's phrase dictionary.
        load_phrases_dict(poly_correct)
        print(f"Loaded polyphone corrections from {poly_correct_path}")
    except Exception as e:
        # Best effort: an unreadable or malformed file only produces a warning.
        print(f"Warning: Failed to load polyphone corrections: {e}")
|
|
# Bare nasal syllables ("n", "ng", "m"); _split_py prefixes these with "e".
re_special_pinyin = re.compile(r'^(n|ng|m)$')


# Digit -> English word. _lang_seperate uses it to spell out a digit once at
# least four English segments have been seen.
NUMBER_MAP = {
    "0": "zero", "1": "one", "2": "two", "3": "three", "4": "four",
    "5": "five", "6": "six", "7": "seven", "8": "eight", "9": "nine",
}


# Character class of punctuation used in Chinese text:
# 。，？！；：“”‘’《》〈〉【】『』—…、（）
chinese_punctuation_pattern = r'[\u3002\uff0c\uff1f\uff01\uff1b\uff1a\u201c\u201d\u2018\u2019\u300a\u300b\u3008\u3009\u3010\u3011\u300e\u300f\u2014\u2026\u3001\uff08\uff09]'
|
|
def _has_ch_punc(text):
    """Return True if *text* contains any character of chinese_punctuation_pattern."""
    # A Match object is always truthy and a failed search returns None.
    return bool(re.search(chinese_punctuation_pattern, text))
|
|
| def _has_en_punc(text): |
| return text in string.punctuation |
|
|
| def _split_py(py): |
| """Split pinyin with tone into initial (sm) and final (ym).""" |
| if not py: |
| return "", "" |
| |
| tone = "" |
| if py[-1].isdigit(): |
| tone = py[-1] |
| py = py[:-1] |
| |
| sm = "" |
| ym = "" |
| suf_r = "" |
| |
| if re_special_pinyin.match(py): |
| py = 'e' + py |
| if py.endswith('r'): |
| suf_r = 'r' |
| py = py[:-1] |
| |
| if py in ['zi', 'ci', 'si', 'ri']: |
| sm = py[:1] |
| ym = "ii" |
| elif py in ['zhi', 'chi', 'shi']: |
| sm = py[:2] |
| ym = "iii" |
| elif py in ['ya', 'yan', 'yang', 'yao', 'ye', 'yong', 'you']: |
| sm = "" |
| ym = 'i' + py[1:] |
| elif py in ['yi', 'yin', 'ying']: |
| sm = "" |
| ym = py[1:] |
| elif py in ['yu', 'yv', 'yuan', 'yvan', 'yue', 'yve', 'yun', 'yvn']: |
| sm = "" |
| ym = 'v' + py[2:] |
| elif py == 'wu': |
| sm = "" |
| ym = "u" |
| elif py.startswith('w'): |
| sm = "" |
| ym = "u" + py[1:] |
| elif len(py) >= 2 and py[0] in ['j', 'q', 'x'] and py[1] == 'u': |
| sm = py[0] |
| ym = 'v' + py[2:] |
| else: |
| seg_pos = re.search('a|e|i|o|u|v', py) |
| if seg_pos: |
| sm = py[:seg_pos.start()] |
| ym = py[seg_pos.start():] |
| if ym == 'ui': |
| ym = 'uei' |
| elif ym == 'iu': |
| ym = 'iou' |
| elif ym == 'un': |
| ym = 'uen' |
| elif ym == 'ue': |
| ym = 've' |
| else: |
| sm = ym = "" |
| return sm, ym |
| |
| ym += suf_r + tone |
| return sm, ym |
|
|
def _trans_cn(text: str, with_sp=False):
    """Convert a run of Chinese text to initial/final phonemes.

    The text is word-segmented with jieba. Each word is converted with
    pypinyin (Style.TONE3, neutral_tone_with_five=True), and the first entry
    of each returned item is split by _split_py into an initial and a final.
    Empty parts are dropped. Whitespace-only words are skipped, as are words
    whose pinyin contains Chinese or English punctuation.

    Args:
        text: Chinese text to convert.
        with_sp: If true, append an "sp" token after each converted word,
            but only once the result is non-empty.

    Returns:
        List of phoneme strings with tone digits still attached to finals.
    """
    result = []
    for word in jieba.cut(text):
        if not word.strip():
            continue

        syllables = [
            item[0]
            for item in pinyin(word, style=Style.TONE3, neutral_tone_with_five=True)
        ]

        # Punctuation tokens carry no phonemes; skip the whole word.
        if any(_has_ch_punc(s) for s in syllables) or any(_has_en_punc(s) for s in syllables):
            continue

        for syllable in syllables:
            initial, final = _split_py(syllable)
            result.extend(part for part in (initial, final) if part != "")

        # NOTE(review): confirm whether "sp" should follow every word (as here)
        # or only the whole run.
        if with_sp and result:
            result.append("sp")

    return result
|
|
| def _trans_en(word: str, with_sp=False): |
| """Convert English word to phonemes using g2p_en.""" |
| |
| |
| |
| |
| |
| word_clean = word.strip(string.punctuation) |
| if not word_clean: |
| return [] |
| |
| phonemes = g2p(word_clean) |
| |
| phonemes = [p for p in phonemes if p.strip() and p not in string.punctuation] |
| |
| if len(phonemes) > 0 and with_sp: |
| phonemes.append("sp") |
| return phonemes |
|
|
| def _char_lang(c: str) -> int: |
| """ |
| 0 - Chinese |
| 1 - English |
| 2 - Digit |
| 3 - Other |
| """ |
| if '\u4e00' <= c <= '\u9fff': |
| return 0 |
| elif ('a' <= c <= 'z') or ('A' <= c <= 'Z'): |
| return 1 |
| elif c.isdigit(): |
| return 2 |
| else: |
| return 3 |
|
|
| def _lang_seperate(text: str): |
| """ |
| Separate text by language. |
| Replicating logic from meta_phonemes.py more closely. |
| """ |
| lang_segs = [] |
| lang_tags = [] |
| lang_seg = "" |
| lang_tag = -1 |
| en_count = 0 |
| |
| for c in text: |
| lang = _char_lang(c) |
| if lang_tag != lang: |
| |
| if lang_seg != "": |
| lang_segs.append(lang_seg) |
| lang_tags.append(lang_tag) |
| if lang_tag == 1: |
| en_count += 1 |
| lang_seg = "" |
| |
| |
| if lang == 2 and en_count >= 4: |
| |
| if c in NUMBER_MAP: |
| lang_segs.append(NUMBER_MAP[c]) |
| lang_tags.append(1) |
| |
| |
| |
| lang_tag = lang |
| |
| |
| else: |
| lang_tag = lang |
| |
| if lang < 2: |
| lang_seg += c |
| |
| |
| |
| if lang_seg != "": |
| lang_segs.append(lang_seg) |
| lang_tags.append(lang_tag) |
| |
| return lang_segs, lang_tags |
|
|
def _phoneme_trans(text: str, with_sp=False):
    """Convert mixed Chinese/English text to a flat phoneme list.

    This mirrors the meta_phonemes logic. Segments tagged 0 by _lang_seperate
    are converted with _trans_cn. Every other segment is split on whitespace,
    and each word is converted with _trans_en. Results are concatenated in
    segment order.
    """
    result = []
    for segment, tag in zip(*_lang_seperate(text)):
        if tag == 0:
            result.extend(_trans_cn(segment, with_sp))
            continue
        for word in segment.split():
            result.extend(_trans_en(word, with_sp))
    return result
|
|
def get_phonemes(text: str, with_sp=False, remove_tones=True) -> list:
    """Convert *text* to a list of phoneme strings (module entry point).

    Args:
        text: Mixed Chinese/English input. Characters other than CJK ideographs
            and ASCII letters only separate segments (see _lang_seperate for
            how digits are handled).
        with_sp: Passed through to the per-word converters, which append
            "sp" tokens when it is true.
        remove_tones: If true (the default), delete every run of digits from
            each phoneme. This removes pinyin tone numbers and any digits
            emitted by g2p_en (presumably ARPAbet stress markers).

    Returns:
        List of phoneme strings.
    """
    phonemes = _phoneme_trans(text, with_sp)
    if not remove_tones:
        return phonemes
    return [re.sub(r'\d+', '', ph) for ph in phonemes]
|
|
def calc_per(ref_phonemes, hyp_phonemes):
    """Return the phoneme error rate of *hyp_phonemes* against *ref_phonemes*.

    PER = (S + D + I) / N. The numerator is the Levenshtein distance between
    the two token sequences, and N is the reference length.

    An empty reference has no defined rate. It yields 0.0 when the hypothesis
    is also empty and 1.0 otherwise.
    """
    if ref_phonemes:
        errors = levenshtein_distance(ref_phonemes, hyp_phonemes)
        return errors / len(ref_phonemes)
    return 1.0 if hyp_phonemes else 0.0
|
|
def levenshtein_distance(s1, s2):
    """Return the edit distance between sequences *s1* and *s2*.

    This is the minimum number of single-element insertions, deletions and
    substitutions (unit cost each) needed to turn one sequence into the other.
    Elements are compared with ``!=``, so both lists of phoneme strings and
    plain strings work. Only two DP rows are kept, and rows run over the
    shorter sequence, so extra memory is O(min(len(s1), len(s2))).
    """
    longer, shorter = (s2, s1) if len(s1) < len(s2) else (s1, s2)
    if len(shorter) == 0:
        return len(longer)

    prev = list(range(len(shorter) + 1))
    for row, item in enumerate(longer, start=1):
        cur = [row]
        for col, other in enumerate(shorter, start=1):
            cur.append(min(
                prev[col] + 1,                    # drop item from longer
                cur[col - 1] + 1,                 # drop other from shorter
                prev[col - 1] + (item != other),  # match or substitute
            ))
        prev = cur
    return prev[-1]
|
|