Spaces:
Running
Running
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import torch | |
| import numpy as np | |
| from examples.textless_nlp.gslm.unit2speech.tacotron2.text import ( | |
| EOS_TOK, | |
| SOS_TOK, | |
| code_to_sequence, | |
| text_to_sequence, | |
| ) | |
| from examples.textless_nlp.gslm.unit2speech.tacotron2.utils import ( | |
| load_code_dict, | |
| ) | |
class TacotronInputDataset:
    """Convert single input strings into token-id tensors for Tacotron2.

    Depending on ``hparams.text_or_code`` the input string is treated either
    as raw text (tokenized with ``text_to_sequence``) or as a
    whitespace-separated sequence of discrete unit codes (tokenized with
    ``code_to_sequence`` against a code dictionary loaded at construction).

    Despite the name, this class holds no dataset of its own: it converts
    one string at a time via :meth:`get_tensor`.
    """

    def __init__(self, hparams, append_str=""):
        """Configure the converter from ``hparams``.

        Args:
            hparams: Config object. ``text_or_code`` is read with a default
                of ``"text"``; any other value selects code mode
                (NOTE(review): a typo such as ``"txt"`` therefore silently
                selects code mode). In code mode the attributes
                ``code_dict``, ``code_key``, ``add_sos``, ``add_eos`` and
                ``collapse_code`` must also be present.
            append_str: Suffix appended to every input string before it is
                tokenized.
        """
        self.is_text = getattr(hparams, "text_or_code", "text") == "text"
        if not self.is_text:
            # Code-mode settings are only read (and only required) in code mode.
            self.code_dict = load_code_dict(hparams.code_dict)
            # Not read by this class's own methods; kept as a public attribute.
            self.code_key = hparams.code_key
            self.add_sos = hparams.add_sos
            self.add_eos = hparams.add_eos
            self.collapse_code = hparams.collapse_code
        self.append_str = append_str

    def process_code(self, inp_str):
        """Tokenize a whitespace-separated string of unit codes.

        The tokens are optionally wrapped in ``SOS_TOK`` / ``EOS_TOK``
        (controlled by ``add_sos`` / ``add_eos``) and then mapped to ids by
        ``code_to_sequence`` using the loaded code dictionary.
        ``collapse_code`` is forwarded to that helper (presumably it merges
        repeated consecutive codes — confirm against its definition).
        """
        inp_toks = inp_str.split()
        if self.add_sos:
            inp_toks = [SOS_TOK] + inp_toks
        if self.add_eos:
            inp_toks = inp_toks + [EOS_TOK]
        return code_to_sequence(inp_toks, self.code_dict, self.collapse_code)

    def process_text(self, inp_str):
        """Tokenize raw text by passing the ``"english_cleaners"`` cleaner
        name to ``text_to_sequence``."""
        return text_to_sequence(inp_str, ["english_cleaners"])

    def get_tensor(self, inp_str):
        """Return ``inp_str`` (with ``append_str`` appended) as token ids.

        Args:
            inp_str: Raw text in text mode, or whitespace-separated unit
                codes in code mode.

        Returns:
            A ``torch.long`` tensor built from the ids produced by
            :meth:`process_text` or :meth:`process_code`.
        """
        inp_str = inp_str + self.append_str
        if self.is_text:
            inp_toks = self.process_text(inp_str)
        else:
            inp_toks = self.process_code(inp_str)
        # .long() normalizes whatever dtype numpy inferred (e.g. int32 on some
        # platforms, or float64 for an empty token list) to int64.
        return torch.from_numpy(np.array(inp_toks)).long()

    def __len__(self):
        """Raise ``TypeError``: this converter holds no data to measure.

        No ``data`` attribute is ever assigned on this class, so there is no
        length to report. ``TypeError`` is the exception ``len()`` normally
        raises for objects without a size, and the message says why.
        """
        raise TypeError(
            f"{type(self).__name__} converts single strings via get_tensor() "
            "and holds no data, so it has no length"
        )