realtime-tts / inference.py
drixo's picture
Update inference.py
90a21e7 verified
import torch
import torchaudio
from .model import RealtimeTTS
from .tokenizer import TTSTokenizer
from .config import TTSConfig
class TTSInference:
    """Text-to-speech inference wrapper.

    Loads a trained ``RealtimeTTS`` checkpoint and a HiFi-GAN vocoder,
    and converts input text to an audio waveform tensor.
    """

    def __init__(self, model_path, tokenizer_path, device=None):
        """Load model weights, tokenizer, and vocoder onto the target device.

        Args:
            model_path: Path to a saved ``state_dict`` checkpoint.
            tokenizer_path: Path passed through to ``TTSTokenizer``.
            device: Torch device string; defaults to CUDA when available.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.config = TTSConfig()
        self.model = RealtimeTTS(self.config).to(self.device)
        # weights_only=True prevents arbitrary pickled code in an untrusted
        # checkpoint from executing; a plain state_dict loads fine either way.
        state_dict = torch.load(
            model_path, map_location=self.device, weights_only=True
        )
        self.model.load_state_dict(state_dict)
        self.model.eval()
        self.tokenizer = TTSTokenizer(tokenizer_path)
        # NOTE(review): released torchaudio exposes this bundle as
        # HIFIGAN_VOCODER_V3_LJSPEECH under torchaudio.prototype.pipelines;
        # confirm HIFIGAN_VOCODER_V3 exists in the pinned torchaudio version.
        self.vocoder = (
            torchaudio.pipelines.HIFIGAN_VOCODER_V3
            .get_model()
            .to(self.device)
        )

    @torch.no_grad()
    def synthesize(self, text: str):
        """Synthesize speech for *text*.

        Returns:
            A 1-D CPU audio tensor (batch/channel dims squeezed away).
        """
        token_ids = self.tokenizer.encode(text)
        # Build directly on the target device instead of CPU tensor + .to().
        tokens = torch.tensor(token_ids, device=self.device).unsqueeze(0)
        # Zero-initialized decoder input: one frame per input token, width
        # d_model. Assumes the model accepts this as its second argument
        # (teacher-forcing slot) at inference — TODO confirm against model.
        mel_input = torch.zeros(
            1, tokens.size(1), self.config.d_model, device=self.device
        )
        mel = self.model(tokens, mel_input)
        # transpose(1, 2): presumably the model emits (batch, frames, mels)
        # while the vocoder wants (batch, mels, frames) — verify shapes.
        audio = self.vocoder(mel.transpose(1, 2))
        return audio.squeeze().cpu()