Spaces:
Runtime error
Runtime error
| import os | |
| import pickle | |
| import timeit | |
| import numpy as np | |
| import onnxruntime as ort | |
| from nix.tokenizers.tokenizer_en import NixTokenizerEN | |
| class NixTTSInference: | |
| def __init__( | |
| self, | |
| model_dir, | |
| ): | |
| # Load tokenizer | |
| self.tokenizer = NixTokenizerEN(pickle.load(open(os.path.join(model_dir, "tokenizer_state.pkl"), "rb"))) | |
| # Load TTS model | |
| self.encoder = ort.InferenceSession(os.path.join(model_dir, "encoder.onnx")) | |
| self.decoder = ort.InferenceSession(os.path.join(model_dir, "decoder.onnx")) | |
| def tokenize( | |
| self, | |
| text, | |
| ): | |
| # Tokenize input text | |
| c, c_lengths, phonemes = self.tokenizer([text]) | |
| return np.array(c, dtype = np.int64), np.array(c_lengths, dtype = np.int64), phonemes | |
| def vocalize( | |
| self, | |
| c, | |
| c_lengths, | |
| ): | |
| """ | |
| Single-batch TTS inference | |
| """ | |
| # Infer latent samples from encoder | |
| z = self.encoder.run( | |
| None, | |
| { | |
| "c": c, | |
| "c_lengths": c_lengths, | |
| } | |
| )[2] | |
| # Decode raw audio with decoder | |
| xw = self.decoder.run( | |
| None, | |
| { | |
| "z": z, | |
| } | |
| )[0] | |
| return xw | |