| |
| |
|
|
| from typing import Dict, List |
|
|
| import onnxruntime |
| import soundfile |
| import torch |
|
|
|
|
def display(sess):
    """Print every input of ``sess``, a dashed separator, then every output.

    Args:
      sess:
        Any object exposing ``get_inputs()`` and ``get_outputs()`` that
        return iterables; each element is passed to ``print`` as-is.
    """
    for input_node in sess.get_inputs():
        print(input_node)

    # Visual separator between the input list and the output list.
    print("-" * 10)

    for output_node in sess.get_outputs():
        print(output_node)
|
|
|
|
class OnnxModel:
    """Wrapper around an ONNX Runtime inference session.

    The wrapped model is fed five inputs, in the order reported by the
    session: token ids, their length, and three float32 scale values.
    Only the first model output is returned.

    Attributes set by ``__init__``:
      session_opts: the ``onnxruntime.SessionOptions`` used for the session.
      model: the ``onnxruntime.InferenceSession``.
      add_blank: ``int(meta["add_blank"])`` from the model's custom metadata.
      sample_rate: ``int(meta["sample_rate"])`` from the custom metadata.
      punctuation: ``meta["punctuation"].split()`` (a list of strings).
    """

    def __init__(
        self,
        model: str,
    ):
        """Create the session and read the model's custom metadata.

        Args:
          model:
            Path of the ONNX model file handed to
            ``onnxruntime.InferenceSession``.

        Raises:
          KeyError:
            If the custom metadata map lacks ``add_blank``, ``sample_rate``
            or ``punctuation``.
        """
        session_opts = onnxruntime.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 4

        self.session_opts = session_opts

        self.model = onnxruntime.InferenceSession(
            model,
            sess_options=self.session_opts,
        )
        # Dump the model's input/output descriptions for debugging.
        display(self.model)

        meta = self.model.get_modelmeta().custom_metadata_map
        self.add_blank = int(meta["add_blank"])
        self.sample_rate = int(meta["sample_rate"])
        self.punctuation = meta["punctuation"].split()
        print(meta)

    def __call__(
        self,
        x: torch.Tensor,
        *,
        noise_scale: float = 1.0,
        length_scale: float = 1.0,
        noise_scale_w: float = 1.0,
    ) -> torch.Tensor:
        """Run the model on one sequence of token ids.

        Args:
          x:
            A int64 tensor of shape (L,)
          noise_scale:
            Value fed to the model's third input as a float32 tensor of
            shape (1,). Defaults to 1.0, the value previously hard-coded.
          length_scale:
            Value fed to the model's fourth input, same format.
            NOTE(review): presumably controls the output duration; confirm
            against the model's export script.
          noise_scale_w:
            Value fed to the model's fifth input, same format.

        Returns:
          The first model output converted to a torch tensor with all
          size-1 dimensions squeezed away.
        """
        # Add a batch dimension: (L,) -> (1, L).
        x = x.unsqueeze(0)
        x_length = torch.tensor([x.shape[1]], dtype=torch.int64)
        noise_scale_t = torch.tensor([noise_scale], dtype=torch.float32)
        length_scale_t = torch.tensor([length_scale], dtype=torch.float32)
        noise_scale_w_t = torch.tensor([noise_scale_w], dtype=torch.float32)

        # Inputs are bound by position, so query the session only once.
        inputs = self.model.get_inputs()

        y = self.model.run(
            [
                self.model.get_outputs()[0].name,
            ],
            {
                inputs[0].name: x.numpy(),
                inputs[1].name: x_length.numpy(),
                inputs[2].name: noise_scale_t.numpy(),
                inputs[3].name: length_scale_t.numpy(),
                inputs[4].name: noise_scale_w_t.numpy(),
            },
        )[0]
        return torch.from_numpy(y).squeeze()
|
|
|
|
def read_lexicon(filename: str = "./lexicon.txt") -> Dict[str, List[str]]:
    """Read a whitespace-separated pronunciation lexicon.

    Each non-blank line has the form ``word phone1 phone2 ...``. Blank
    lines are skipped (they previously raised ``IndexError``). If a word
    occurs more than once, the last occurrence wins.

    Args:
      filename:
        Path of the lexicon file. Defaults to ``./lexicon.txt``.

    Returns:
      A dict mapping each word to its list of phone strings.
    """
    ans = {}
    with open(filename, encoding="utf-8") as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Blank or whitespace-only line: nothing to record.
                continue
            word, *phones = fields
            ans[word] = phones
    return ans
|
|
|
|
def read_tokens(filename: str = "./tokens.txt") -> Dict[str, int]:
    """Read the token table mapping token strings to integer ids.

    Each non-blank line is either ``token id`` or just ``id``. A line with
    only an id is taken to describe the space token ``" "``, because a
    literal space disappears when the line is split on whitespace.

    Args:
      filename:
        Path of the token file. Defaults to ``./tokens.txt``.

    Returns:
      A dict mapping each token to its integer id.

    Raises:
      ValueError:
        If a line has more than two whitespace-separated fields, or the
        id field is not an integer.
    """
    ans = {}
    with open(filename, encoding="utf-8") as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Blank line carries no id at all; skip it.
                continue

            if len(fields) == 1:
                token, idx = " ", fields[0]
            elif len(fields) == 2:
                token, idx = fields
            else:
                # Explicit raise instead of assert so the check survives -O.
                raise ValueError(
                    f"Expected '<token> <id>' in {filename}, "
                    f"got {fields} from line {line!r}"
                )
            ans[token] = int(idx)
    return ans
|
|
|
|
def convert_lexicon(lexicon, tokens):
    """Convert, in place, each word's phone strings to token ids.

    Words containing a phone that is missing from ``tokens`` are removed
    from ``lexicon``. Leaving them in place would keep their phones as
    strings, which later get mixed into an integer id sequence.

    Args:
      lexicon:
        Dict mapping a word to its list of phones. Modified in place.
      tokens:
        Dict mapping a phone string to its integer id.

    Returns:
      None.
    """
    unconvertible = []
    for word, phones in lexicon.items():
        if all(isinstance(p, int) for p in phones):
            # Already ids (e.g. a repeated call): leave the entry alone.
            continue
        try:
            # Replacing the value of an existing key is safe while iterating.
            lexicon[word] = [tokens[p] for p in phones]
        except KeyError:
            unconvertible.append(word)

    # Delete after the loop: a dict must not change size during iteration.
    for word in unconvertible:
        del lexicon[word]
|
|
|
|
| """ |
| skip rapprochement |
| skip croissants |
| skip aix-en-provence |
| skip provence |
| skip croissant |
| skip denouement |
| skip hola |
| skip blanc |
| """ |
|
|
|
|
def get_text(text, lexicon, tokens, punctuation):
    """Convert a sentence into a flat list of token ids.

    The text is lower-cased and split on whitespace. For each word, at most
    one leading and one trailing punctuation character are stripped and
    emitted as their own tokens. The remaining word is looked up in
    ``lexicon``. A space token is appended after every word except the last.

    Words missing from ``lexicon`` are reported with ``print("ignore", w)``
    and contribute neither phones, trailing punctuation nor a following
    space. Their leading punctuation, if any, has already been emitted.

    Args:
      text:
        The sentence to convert.
      lexicon:
        Dict mapping a lower-case word to its list of token ids.
      tokens:
        Dict mapping a token string (including ``" "`` and each punctuation
        character) to its integer id.
      punctuation:
        Container of single-character punctuation strings.

    Returns:
      A list of token ids.

    Raises:
      KeyError:
        If a needed punctuation character or ``" "`` is missing from
        ``tokens``.
    """
    words = text.lower().split()
    last = len(words) - 1
    ans = []
    for i, w in enumerate(words):
        punct = None

        if w[0] in punctuation:
            ans.append(tokens[w[0]])
            w = w[1:]

        # `w` may now be empty (the word was a lone punctuation mark), so
        # guard the index to avoid IndexError.
        if w and w[-1] in punctuation:
            punct = tokens[w[-1]]
            w = w[:-1]

        if not w:
            # The word was made of punctuation only: emit what was
            # collected and keep the usual inter-word space.
            if punct is not None:
                ans.append(punct)
            if i != last:
                ans.append(tokens[" "])
            continue

        if w in lexicon:
            ans.extend(lexicon[w])
            # Compare with None: a punctuation token may have id 0.
            if punct is not None:
                ans.append(punct)

            if i != last:
                ans.append(tokens[" "])
            continue

        print("ignore", w)
    return ans
|
|
|
|
def generate(model, text, lexicon, tokens):
    """Run ``model`` on ``text`` and return its output.

    Args:
      model:
        Callable taking a 1-D int64 tensor of token ids; must also expose
        ``punctuation`` and ``add_blank`` attributes.
      text:
        The sentence to synthesize.
      lexicon:
        Word-to-token-ids mapping passed through to ``get_text``.
      tokens:
        Token-to-id mapping passed through to ``get_text``.

    Returns:
      Whatever ``model`` returns for the token id tensor.
    """
    token_ids = get_text(text, lexicon, tokens, model.punctuation)

    if model.add_blank:
        # Surround every id with the blank id 0: [0, t1, 0, t2, 0, ...].
        with_blanks = [0]
        for token_id in token_ids:
            with_blanks.append(token_id)
            with_blanks.append(0)
        token_ids = with_blanks

    return model(torch.tensor(token_ids, dtype=torch.int64))
|
|
|
|
def main():
    """Synthesize three sample sentences and save each as a wav file.

    Loads ``./vits-ljs.onnx`` plus the lexicon and token table from their
    default paths, then writes ``test-0.wav``, ``test-1.wav`` and
    ``test-2.wav`` at the model's sample rate.
    """
    model = OnnxModel("./vits-ljs.onnx")

    lexicon = read_lexicon()
    tokens = read_tokens()
    convert_lexicon(lexicon, tokens)

    # (sentence, output file) pairs, processed in order.
    jobs = (
        (
            "Liliana, our most beautiful and lovely assistant",
            "test-0.wav",
        ),
        (
            "Ask not what your country can do for you; ask what you can do for your country.",
            "test-1.wav",
        ),
        (
            "Success is not final, failure is not fatal, it is the courage to continue that counts!",
            "test-2.wav",
        ),
    )
    for sentence, wav_name in jobs:
        audio = generate(model, sentence, lexicon, tokens)
        soundfile.write(wav_name, audio.numpy(), model.sample_rate)


if __name__ == "__main__":
    main()
|
|