| import torch |
| import torchaudio |
|
|
| from .model import RealtimeTTS |
| from .tokenizer import TTSTokenizer |
| from .config import TTSConfig |
|
|
|
|
class TTSInference:
    """Run text-to-speech inference with a trained RealtimeTTS model.

    Loads the acoustic model from a checkpoint, a tokenizer from disk,
    and a pretrained HiFi-GAN vocoder, then converts text to a waveform
    via :meth:`synthesize`.
    """

    def __init__(self, model_path: str, tokenizer_path: str, device: str | None = None):
        """Initialize the inference pipeline.

        Args:
            model_path: Path to the model checkpoint (a ``state_dict``
                saved with ``torch.save``).
            tokenizer_path: Path passed through to ``TTSTokenizer``.
            device: Torch device string; defaults to ``"cuda"`` when
                available, else ``"cpu"``.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        self.config = TTSConfig()

        self.model = RealtimeTTS(self.config).to(self.device)
        # weights_only=True prevents arbitrary code execution from a
        # malicious/corrupted pickle; the checkpoint is a plain state_dict,
        # which weights-only loading fully supports.
        self.model.load_state_dict(
            torch.load(model_path, map_location=self.device, weights_only=True)
        )
        self.model.eval()

        self.tokenizer = TTSTokenizer(tokenizer_path)

        # NOTE(review): HIFIGAN_VOCODER_V3 is assumed to exist in this
        # torchaudio version's pipelines namespace — confirm against the
        # pinned torchaudio release (stable releases expose HiFi-GAN
        # bundles under dataset-specific names).
        self.vocoder = (
            torchaudio.pipelines.HIFIGAN_VOCODER_V3
            .get_model()
            .to(self.device)
        )

    @torch.no_grad()
    def synthesize(self, text: str) -> torch.Tensor:
        """Convert *text* to an audio waveform.

        Args:
            text: Input text to synthesize.

        Returns:
            A 1-D float tensor of audio samples on the CPU.
        """
        tokens = self.tokenizer.encode(text)
        # Token IDs must be int64 for embedding lookups; make the dtype
        # explicit rather than relying on torch.tensor's inference.
        tokens = torch.tensor(tokens, dtype=torch.long, device=self.device).unsqueeze(0)

        # Zero-initialized decoder input, one frame per input token.
        # NOTE(review): mel frames usually outnumber tokens — this
        # one-to-one sizing is presumably what RealtimeTTS expects;
        # verify against the model's forward() contract.
        mel_input = torch.zeros(
            1, tokens.size(1), self.config.d_model, device=self.device
        )

        mel = self.model(tokens, mel_input)

        # Vocoder expects (batch, n_mels, frames); model emits
        # (batch, frames, n_mels), hence the transpose.
        audio = self.vocoder(mel.transpose(1, 2))

        return audio.squeeze().cpu()
|
|