| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ |
| Credits |
| This code is modified from https://github.com/GitYCC/g2pW |
| """ |
| from typing import Dict |
| from typing import List |
| from typing import Tuple |
|
|
| import numpy as np |
|
|
| from .utils import tokenize_and_map |
|
|
# Module-level anchor character '▁'. It is not referenced in the functions
# below; NOTE(review): presumably consumed elsewhere in the package — confirm
# its purpose before changing it.
ANCHOR_CHAR = '▁'
|
|
|
|
def prepare_onnx_input(tokenizer,
                       labels: List[str],
                       char2phonemes: Dict[str, List[int]],
                       chars: List[str],
                       texts: List[str],
                       query_ids: List[int],
                       use_mask: bool=False,
                       window_size: int=None,
                       max_len: int=512) -> Dict[str, np.array]:
    """Build the batched ONNX feature dict for g2pW inference.

    For every ``(text, query_id)`` pair the text is optionally cropped to a
    character window around the query, lower-cased, tokenized, cropped again
    so that at most ``max_len`` positions (including ``[CLS]`` and ``[SEP]``)
    remain, and converted into model inputs.

    Args:
        tokenizer: Passed to ``tokenize_and_map``; must also provide
            ``convert_tokens_to_ids``.
        labels: Phoneme label list; its length is the phoneme-mask width.
        char2phonemes: Maps a character to the label indices allowed for it.
            Only read when ``use_mask`` is true.
        chars: Character vocabulary; ``char_ids`` holds each query
            character's index in this list.
        texts: Input texts, one per query.
        query_ids: Index of the query character inside the matching text.
        use_mask: If true, the phoneme mask only enables the query
            character's candidate labels; otherwise every label is enabled.
        window_size: If truthy, crop each text to a character window around
            its query before tokenizing.
        max_len: Maximum sequence length including ``[CLS]`` and ``[SEP]``.

    Returns:
        A dict with ``input_ids``, ``token_type_ids`` and ``attention_masks``
        (int64, shape ``(batch, longest_sequence)``, right-padded with 0),
        ``phoneme_masks`` (float32, ``(batch, len(labels))``), and
        ``char_ids`` / ``position_ids`` (int64, ``(batch,)``). Returns an
        empty dict if any text fails to tokenize.
    """
    if window_size:
        # Crop each text to a character window around its query first.
        texts, query_ids = _truncate_texts(
            window_size=window_size, texts=texts, query_ids=query_ids)

    num_labels = len(labels)
    sequences = []
    phoneme_masks = []
    char_ids = []
    position_ids = []
    for text, query_id in zip(texts, query_ids):
        text = text.lower()
        try:
            tokens, text2token, token2text = tokenize_and_map(
                tokenizer=tokenizer, text=text)
        except Exception:
            # Best effort: a single untokenizable text aborts the whole batch
            # and the caller receives an empty dict.
            print(f'warning: text "{text}" is invalid')
            return {}

        text, query_id, tokens, text2token, token2text = _truncate(
            max_len=max_len,
            text=text,
            query_id=query_id,
            tokens=tokens,
            text2token=text2token,
            token2text=token2text)

        sequences.append(
            tokenizer.convert_tokens_to_ids(['[CLS]'] + tokens + ['[SEP]']))

        query_char = text[query_id]
        if use_mask:
            allowed = set(char2phonemes[query_char])
            phoneme_masks.append(
                [1 if i in allowed else 0 for i in range(num_labels)])
        else:
            phoneme_masks.append([1] * num_labels)
        char_ids.append(chars.index(query_char))
        # +1 skips the leading [CLS] token.
        position_ids.append(text2token[query_id] + 1)

    # Samples can tokenize to different lengths (e.g. windows clipped at a
    # text boundary, or texts from different sentences); np.array on ragged
    # rows raises on NumPy >= 1.24. Right-pad every row to the batch maximum
    # and mark only real tokens in the attention mask. Pad id 0 is assumed to
    # be the tokenizer's [PAD] id (true for standard BERT vocabularies) —
    # padded positions are masked out either way.
    batch_size = len(sequences)
    seq_len = max((len(seq) for seq in sequences), default=0)
    input_ids = np.zeros((batch_size, seq_len), dtype=np.int64)
    attention_masks = np.zeros((batch_size, seq_len), dtype=np.int64)
    for row, seq in enumerate(sequences):
        input_ids[row, :len(seq)] = seq
        attention_masks[row, :len(seq)] = 1

    return {
        'input_ids': input_ids,
        'token_type_ids': np.zeros((batch_size, seq_len), dtype=np.int64),
        'attention_masks': attention_masks,
        'phoneme_masks': np.array(phoneme_masks, dtype=np.float32),
        'char_ids': np.array(char_ids, dtype=np.int64),
        'position_ids': np.array(position_ids, dtype=np.int64),
    }
|
|
|
|
| def _truncate_texts(window_size: int, texts: List[str], |
| query_ids: List[int]) -> Tuple[List[str], List[int]]: |
| truncated_texts = [] |
| truncated_query_ids = [] |
| for text, query_id in zip(texts, query_ids): |
| start = max(0, query_id - window_size // 2) |
| end = min(len(text), query_id + window_size // 2) |
| truncated_text = text[start:end] |
| truncated_texts.append(truncated_text) |
|
|
| truncated_query_id = query_id - start |
| truncated_query_ids.append(truncated_query_id) |
| return truncated_texts, truncated_query_ids |
|
|
|
|
| def _truncate(max_len: int, |
| text: str, |
| query_id: int, |
| tokens: List[str], |
| text2token: List[int], |
| token2text: List[Tuple[int]]): |
| truncate_len = max_len - 2 |
| if len(tokens) <= truncate_len: |
| return (text, query_id, tokens, text2token, token2text) |
|
|
| token_position = text2token[query_id] |
|
|
| token_start = token_position - truncate_len // 2 |
| token_end = token_start + truncate_len |
| font_exceed_dist = -token_start |
| back_exceed_dist = token_end - len(tokens) |
| if font_exceed_dist > 0: |
| token_start += font_exceed_dist |
| token_end += font_exceed_dist |
| elif back_exceed_dist > 0: |
| token_start -= back_exceed_dist |
| token_end -= back_exceed_dist |
|
|
| start = token2text[token_start][0] |
| end = token2text[token_end - 1][1] |
|
|
| return (text[start:end], query_id - start, tokens[token_start:token_end], [ |
| i - token_start if i is not None else None |
| for i in text2token[start:end] |
| ], [(s - start, e - start) for s, e in token2text[token_start:token_end]]) |
|
|
|
|
def get_phoneme_labels(polyphonic_chars: List[List[str]]
                       ) -> Tuple[List[str], Dict[str, List[int]]]:
    """Collect the phoneme vocabulary and each character's candidate ids.

    Args:
        polyphonic_chars: ``(char, phoneme)`` pairs.

    Returns:
        ``(labels, char2phonemes)``: ``labels`` is the sorted list of
        distinct phonemes; ``char2phonemes`` maps every character to the
        positions in ``labels`` of its phonemes, one entry per input pair
        (duplicates included), in input order.
    """
    labels = sorted({phoneme for _, phoneme in polyphonic_chars})
    position_of = {phoneme: pos for pos, phoneme in enumerate(labels)}
    char2phonemes = {}
    for char, phoneme in polyphonic_chars:
        char2phonemes.setdefault(char, []).append(position_of[phoneme])
    return labels, char2phonemes
|
|
|
|
def get_char_phoneme_labels(polyphonic_chars: List[List[str]]
                            ) -> Tuple[List[str], Dict[str, List[int]]]:
    """Build ``"<char> <phoneme>"`` labels and a char -> label-index map.

    Args:
        polyphonic_chars: ``(char, phoneme)`` pairs. Duplicate pairs are
            allowed; each occurrence appends its label index again.

    Returns:
        ``(labels, char2phonemes)`` where ``labels`` is the sorted list of
        distinct ``f'{char} {phoneme}'`` strings and ``char2phonemes`` maps
        each character to the indices (into ``labels``) of its pairs, in
        input order.
    """
    labels = sorted(
        {f'{char} {phoneme}' for char, phoneme in polyphonic_chars})
    # A dict lookup replaces labels.index(), which made the loop below
    # quadratic: there is one label per distinct (char, phoneme) pair.
    label_to_index = {label: idx for idx, label in enumerate(labels)}
    char2phonemes = {}
    for char, phoneme in polyphonic_chars:
        char2phonemes.setdefault(char, []).append(
            label_to_index[f'{char} {phoneme}'])
    return labels, char2phonemes
|
|