| | |
| | |
| | |
| | |
| | |
| | from pathlib import Path |
| | from typing import List, Tuple |
| | import os |
| | import numpy as np |
| | import torch |
| | from text.symbol_table import SymbolTable |
| | from text import text_to_sequence |
| |
|
| |
|
'''
TextTokenCollator: maps text tokens to integer ids.
'''
| | |
| | |
class TextTokenCollator:
    """Maps text tokens to integer ids and collates batches of sequences.

    The vocabulary is built with the pad symbol at index 0, optionally
    followed by the bos/eos symbols, then the sorted input tokens.
    """

    def __init__(
        self,
        text_tokens: List[str],
        add_eos: bool = True,
        add_bos: bool = True,
        pad_symbol: str = "<pad>",
        bos_symbol: str = "<bos>",
        eos_symbol: str = "<eos>",
    ):
        """
        Args:
            text_tokens: vocabulary tokens (sorted internally before indexing).
            add_eos: append ``eos_symbol`` to every sequence.
            add_bos: prepend ``bos_symbol`` to every sequence.
            pad_symbol: padding token; always assigned id 0.
            bos_symbol: beginning-of-sequence token.
            eos_symbol: end-of-sequence token.
        """
        self.pad_symbol = pad_symbol
        self.add_eos = add_eos
        self.add_bos = add_bos
        self.bos_symbol = bos_symbol
        self.eos_symbol = eos_symbol

        # Pad is always id 0; bos/eos (when enabled) come next, then the
        # sorted vocabulary tokens.
        unique_tokens = [pad_symbol]
        if add_bos:
            unique_tokens.append(bos_symbol)
        if add_eos:
            unique_tokens.append(eos_symbol)
        unique_tokens.extend(sorted(text_tokens))

        self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)}
        self.idx2token = unique_tokens

    def index(
        self, tokens_list: List[str]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Convert a batch of token sequences to a padded id tensor.

        Args:
            tokens_list: iterable of token sequences; each element is
                iterated item-wise, so a plain string is treated as a
                sequence of characters.

        Returns:
            (tokens, tokens_lens): an int64 tensor of shape
            (batch, max_len) right-padded with the pad id, and an
            IntTensor of the unpadded lengths (bos/eos included when
            enabled).

        Raises:
            AssertionError: if any token is not in the vocabulary.
        """
        seqs, seq_lens = [], []
        for tokens in tokens_list:
            # Every token must be in the vocabulary before indexing.
            assert all(s in self.token2idx for s in tokens)
            seq = (
                ([self.bos_symbol] if self.add_bos else [])
                + list(tokens)
                + ([self.eos_symbol] if self.add_eos else [])
            )
            seqs.append(seq)
            seq_lens.append(len(seq))

        # Right-pad every sequence to the batch maximum with the pad symbol.
        max_len = max(seq_lens)
        for seq, seq_len in zip(seqs, seq_lens):
            seq.extend([self.pad_symbol] * (max_len - seq_len))

        tokens = torch.from_numpy(
            np.array(
                [[self.token2idx[token] for token in seq] for seq in seqs],
                dtype=np.int64,
            )
        )
        tokens_lens = torch.IntTensor(seq_lens)

        return tokens, tokens_lens

    def __call__(self, text) -> Tuple[List[int], int]:
        """Convert a single token sequence to a list of ids.

        Args:
            text: a token sequence; iterated item-wise, so a plain string
                is split into characters.

        Returns:
            (token_ids, token_lens): ids with bos/eos included when
            enabled, and the total sequence length.
        """
        tokens_seq = list(text)
        seq = (
            ([self.bos_symbol] if self.add_bos else [])
            + tokens_seq
            + ([self.eos_symbol] if self.add_eos else [])
        )

        token_ids = [self.token2idx[token] for token in seq]
        # Booleans count as 0/1, adding one per enabled special symbol.
        token_lens = len(tokens_seq) + self.add_eos + self.add_bos

        return token_ids, token_lens
| | |
| |
|
def get_text_token_collater(
    text_tokens_file: str,
) -> Tuple[TextTokenCollator, dict]:
    """Build a TextTokenCollator from a symbol-table file.

    Note: despite the name, this returns a pair — the collator and its
    token-to-id mapping — because callers unpack both (the original
    annotation claimed a bare TextTokenCollator, which was wrong).

    Args:
        text_tokens_file: path to a SymbolTable file listing the vocabulary.

    Returns:
        (collater, token2idx): the collator (bos/eos enabled) and its
        token-to-id dictionary.
    """
    text_tokens_path = Path(text_tokens_file)
    unique_tokens = SymbolTable.from_file(text_tokens_path)
    collater = TextTokenCollator(
        unique_tokens.symbols, add_bos=True, add_eos=True
    )
    return collater, collater.token2idx
| |
|
| |
|
class phoneIDCollation:
    """Converts phoneme sequences into integer id sequences.

    With the 'lexicon' extractor, ids come from ``text_to_sequence``;
    for any other extractor, a TextTokenCollator built from the
    dataset's symbol dictionary performs the mapping.
    """

    def __init__(self, cfg, dataset=None, symbols_dict_file=None) -> None:
        # The lexicon extractor maps phones directly; no symbol
        # dictionary or collator is needed in that case.
        if cfg.preprocess.phone_extractor == 'lexicon':
            return

        if symbols_dict_file is None:
            # Derive the dictionary path from the processed-data layout;
            # a dataset name is required to locate it.
            assert dataset is not None
            symbols_dict_file = os.path.join(
                cfg.preprocess.processed_dir,
                dataset,
                cfg.preprocess.symbols_dict,
            )
        # NOTE: attribute name keeps the historical spelling
        # ("colloator") for compatibility with existing callers.
        self.text_token_colloator, token2idx = get_text_token_collater(
            symbols_dict_file
        )

    def get_phone_id_sequence(self, cfg, phones_seq):
        """Map a phoneme sequence to a sequence of integer ids.

        Args:
            cfg: config whose ``preprocess.phone_extractor`` selects the
                conversion path.
            phones_seq: sequence of phoneme strings.

        Returns:
            list of integer phone ids.
        """
        if cfg.preprocess.phone_extractor == 'lexicon':
            # Lexicon path: join phones into one string and clean/convert.
            joined = ' '.join(phones_seq)
            return text_to_sequence(joined, cfg.preprocess.text_cleaners)
        # Collator path: discard the returned length, keep only the ids.
        sequence, _ = self.text_token_colloator(phones_seq)
        return sequence
| | |