| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import re |
|
|
|
|
def read_lists(list_file):
    """Return every line of ``list_file`` with surrounding whitespace removed.

    Args:
        list_file: path to a UTF-8 encoded text file.

    Returns:
        A list holding one stripped string per line, in file order. Blank
        lines are kept as empty strings.
    """
    with open(list_file, 'r', encoding='utf8') as handle:
        stripped = [raw.strip() for raw in handle]
    return stripped
|
|
|
|
class BadSymbolFormat(Exception):
    """Raised when a non-linguistic symbol is not of the form {x}, <x> or [x]."""


def read_non_lang_symbols(non_lang_sym_path):
    """Read non-linguistic symbols from a file, one symbol per line.

    The file format is like below::

        {NOISE}
        {BRK}
        ...

    Each non-blank line must be a single token of the form ``{xxx}``,
    ``<xxx>`` or ``[xxx]``. Blank lines are ignored.

    Args:
        non_lang_sym_path: path to the non-linguistic symbol file, or None
            when there are no such symbols.

    Returns:
        The list of symbols in file order (an empty list when
        ``non_lang_sym_path`` is None).

    Raises:
        BadSymbolFormat: if a non-blank line is not a bracketed symbol.
    """
    if non_lang_sym_path is None:
        return []

    # Blank lines (e.g. a trailing empty line) are not symbols; skip them
    # rather than rejecting the whole file as malformed.
    syms = [sym for sym in read_lists(non_lang_sym_path) if sym]
    non_lang_syms_pattern = re.compile(r"(\[[^\[\]]+\]|<[^<>]+>|{[^{}]+})")
    for sym in syms:
        if non_lang_syms_pattern.fullmatch(sym) is None:
            raise BadSymbolFormat(
                "Non-linguistic symbols should be "
                "formatted in {xxx}/<xxx>/[xxx], consider"
                " modifying '%s' to meet the requirement. "
                "More details can be found in discussions here: "
                "https://github.com/wenet-e2e/wenet/pull/819" % (sym,))
    return syms
|
|
|
|
def read_symbol_table(symbol_table_file):
    """Load a ``symbol id`` mapping file into a dict.

    Each non-blank line must contain exactly two whitespace-separated
    fields: the symbol and its integer id, e.g. ``<blank> 0``. Blank lines
    are ignored.

    Args:
        symbol_table_file: path to a UTF-8 encoded text file.

    Returns:
        dict mapping each symbol (str) to its id (int). If a symbol appears
        more than once, the last occurrence wins.

    Raises:
        ValueError: if a non-blank line does not have exactly two fields,
            or if its id field is not an integer.
    """
    symbol_table = {}
    with open(symbol_table_file, 'r', encoding='utf8') as fin:
        for lineno, line in enumerate(fin, start=1):
            arr = line.split()
            if not arr:
                # Tolerate blank lines such as a trailing empty line.
                continue
            # Explicit check instead of ``assert``: asserts are removed under
            # ``python -O``, which would let malformed lines slip through.
            if len(arr) != 2:
                raise ValueError(
                    "%s:%d: expected 'symbol id', got %r"
                    % (symbol_table_file, lineno, line.rstrip('\n')))
            symbol_table[arr[0]] = int(arr[1])
    return symbol_table
|
|