# StableTTS Gradio inference demo (web UI for text-to-speech synthesis).
| import os | |
| from dataclasses import asdict | |
| from text import symbols | |
| import torch | |
| import torchaudio | |
| from utils.audio import LogMelSpectrogram | |
| from config import ModelConfig, VocosConfig, MelConfig | |
| from models.model import StableTTS | |
| from vocos_pytorch.models.model import Vocos | |
| from text.english import english_to_ipa2 | |
| from text import cleaned_text_to_sequence | |
| from datas.dataset import intersperse | |
| import gradio as gr | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from pathlib import Path | |
# Inference device shared by all models below; 'cpu' keeps the demo GPU-free.
device = 'cpu'
def inference(text: str, ref_audio: str, checkpoint_path: str, step: int = 10) -> tuple:
    """Synthesise speech for ``text`` in the style of a reference audio clip.

    Args:
        text: English text to synthesise.
        ref_audio: Path to the reference audio file (wav/flac) — it is handed
            straight to ``torchaudio.load`` (the original ``torch.Tensor``
            annotation was wrong).
        checkpoint_path: Path to a StableTTS checkpoint; weights are reloaded
            only when the selection changes between calls.
        step: Number of solver steps for the flow-matching decoder
            (more steps = slower but usually better quality).

    Returns:
        ``(audio_output, mel_output)`` where ``audio_output`` is a
        ``(sample_rate, int16 ndarray)`` pair for ``gr.Audio`` and
        ``mel_output`` is a matplotlib figure for ``gr.Plot``.
    """
    global last_checkpoint_path
    # Reload TTS weights only when the selected checkpoint actually changed.
    if checkpoint_path != last_checkpoint_path:
        tts_model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
        last_checkpoint_path = checkpoint_path

    phonemizer = english_to_ipa2

    # prepare input for tts model: phonemize -> token ids -> intersperse blank id 0
    x = torch.tensor(intersperse(cleaned_text_to_sequence(phonemizer(text)), item=0), dtype=torch.long, device=device).unsqueeze(0)
    x_len = torch.tensor([x.size(-1)], dtype=torch.long, device=device)

    waveform, sr = torchaudio.load(ref_audio)
    if sr != sample_rate:
        waveform = torchaudio.functional.resample(waveform, sr, sample_rate)
    y = mel_extractor(waveform).to(device)

    # inference — disable autograd tracking: this is a pure forward pass and
    # building graphs here only wastes memory and time.
    with torch.inference_mode():
        mel = tts_model.synthesise(x, x_len, step, y=y, temperature=0.667, length_scale=1)['decoder_outputs']
        audio = vocoder(mel)

    # process output for gradio
    audio_output = (sample_rate, (audio.cpu().squeeze(0).numpy() * 32767).astype(np.int16))  # (samplerate, int16 audio) for gr.Audio
    mel_output = plot_mel_spectrogram(mel.cpu().squeeze(0).numpy())  # get the plot of mel
    return audio_output, mel_output
def get_pipeline(n_vocab: int, tts_model_config: ModelConfig, mel_config: MelConfig, vocoder_config: VocosConfig, tts_checkpoint_path, vocoder_checkpoint_path):
    """Construct the TTS model, mel extractor and vocoder, ready for inference.

    Note: the TTS checkpoint itself is intentionally NOT loaded here —
    ``inference`` lazily loads whichever checkpoint the user picks in the UI.
    Only the vocoder weights are loaded eagerly.
    """
    def _ready(module):
        # Move a model onto the inference device and freeze it in eval mode.
        module.to(device)
        module.eval()
        return module

    mel_extractor = LogMelSpectrogram(mel_config)

    tts_model = _ready(StableTTS(n_vocab, mel_config.n_mels, **asdict(tts_model_config)))

    vocoder = Vocos(vocoder_config, mel_config)
    vocoder.load_state_dict(torch.load(vocoder_checkpoint_path, map_location='cpu'))
    _ready(vocoder)

    return tts_model, mel_extractor, vocoder
def plot_mel_spectrogram(mel_spectrogram):
    """Render a mel spectrogram as a borderless matplotlib figure for ``gr.Plot``."""
    figure, axes = plt.subplots(figsize=(20, 8))
    axes.imshow(mel_spectrogram, aspect='auto', origin='lower')
    plt.axis('off')  # hide ticks and frame on the just-created axes
    # stretch the axes to the full canvas so no white margins remain
    figure.subplots_adjust(left=0, right=1, top=1, bottom=0)
    return figure
def main():
    """Build the model pipeline and launch the Gradio demo UI."""
    tts_model_config = ModelConfig()
    mel_config = MelConfig()
    vocoder_config = VocosConfig()

    tts_checkpoint_dir = './checkpoints'  # the folder that contains StableTTS checkpoints
    vocoder_checkpoint_path = './checkpoints/vocoder.pt'

    # `inference` is a Gradio callback and cannot take these as extra
    # arguments, so they are shared through module-level globals.
    global tts_model, mel_extractor, vocoder, sample_rate, last_checkpoint_path
    sample_rate = mel_config.sample_rate
    last_checkpoint_path = None
    tts_model, mel_extractor, vocoder = get_pipeline(len(symbols), tts_model_config, mel_config, vocoder_config, tts_checkpoint_dir, vocoder_checkpoint_path)

    # BUG FIX: the original filter `'optimizer' and 'vocoder' not in path.name`
    # only excluded 'vocoder' ('optimizer' is just a truthy string constant),
    # so optimizer checkpoints appeared in the dropdown. Test both substrings.
    tts_checkpoint_paths = [
        str(path) for path in Path(tts_checkpoint_dir).rglob('*.pt')
        if 'optimizer' not in path.name and 'vocoder' not in path.name
    ]
    audios = [str(path) for path in list(Path('./audios').rglob('*.wav')) + list(Path('./audios').rglob('*.flac'))]

    # gradio webui
    gui_title = 'StableTTS'
    gui_description = """Next-generation TTS model using flow-matching and DiT, inspired by Stable Diffusion 3."""
    with gr.Blocks(analytics_enabled=False) as demo:
        with gr.Row():
            with gr.Column():
                gr.Markdown(f"# {gui_title}")
                gr.Markdown(gui_description)

        with gr.Row():
            with gr.Column():
                input_text_gr = gr.Textbox(
                    label="Input Text",
                    info="One or two sentences at a time is better. Up to 200 text characters.",
                    value="Today I want to tell you three stories from my life. That's it. No big deal. Just three stories.",
                )
                # BUG FIX: `value=0` is not a member of `choices`; default to
                # the first available entry (or None when the folder is empty).
                ref_audio_gr = gr.Dropdown(
                    label='reference audio',
                    choices=audios,
                    value=audios[0] if audios else None,
                )
                checkpoint_gr = gr.Dropdown(
                    label='checkpoint',
                    choices=tts_checkpoint_paths,
                    value=tts_checkpoint_paths[0] if tts_checkpoint_paths else None,
                )
                step_gr = gr.Slider(
                    label='Step',
                    minimum=1,
                    maximum=40,
                    value=8,
                    step=1,
                )
                tts_button = gr.Button("Send", elem_id="send-btn", visible=True)

            with gr.Column():
                mel_gr = gr.Plot(label="Mel Visual")
                audio_gr = gr.Audio(label="Synthesised Audio", autoplay=True)

        tts_button.click(inference, [input_text_gr, ref_audio_gr, checkpoint_gr, step_gr], outputs=[audio_gr, mel_gr])

    demo.queue()
    demo.launch(debug=True, show_api=True)

if __name__ == '__main__':
    main()