Spaces:
Paused
Paused
| import os | |
| import torch | |
| import torchaudio | |
| import text | |
| from . import get_custom_config, get_config | |
| from vocoder import load_hifigan | |
| from vocoder.hifigan.denoiser import Denoiser | |
# Base settings, loaded once at import time. TTSManager reads
# `vocoder_state_path` and `vocoder_config_path` from this object.
config = get_config('./configs/basic.yaml')
# Model registry, loaded once at import time. load_models() iterates this
# object's attributes, treating each as a name -> {'path', 'type'} entry.
models_config = get_custom_config('./app/models.yaml')
def load_models():
    """Instantiate every usable model listed in ``models_config``.

    Each attribute of ``models_config`` is read as ``name -> entry``, where
    the entry provides a ``'path'`` (checkpoint location) and a ``'type'``
    (``'tacotron2'`` or ``'fastpitch'``). Entries whose checkpoint file is
    missing, or whose type is not one of those two, are reported on stdout
    and skipped.

    Returns:
        A list of ``(name, model)`` tuples in registry order.
    """
    loaded = []
    for name, spec in vars(models_config).items():
        checkpoint = spec['path']
        if not os.path.exists(checkpoint):
            print(f"No model @ {checkpoint}")
            continue

        # The model packages are imported lazily, so only the architectures
        # actually listed in the registry get imported.
        kind = spec['type']
        if kind == 'tacotron2':
            from models.tacotron2 import Tacotron2 as model_cls
        elif kind == 'fastpitch':
            from models.fastpitch import FastPitch as model_cls
        else:
            print(f"Model type: {kind} not supported")
            continue

        loaded.append((name, model_cls(checkpoint)))
    return loaded
class TTSManager:
    """Synthesize speech with every loaded model through one shared vocoder.

    The manager owns a HiFi-GAN vocoder and its denoiser (both kept on the
    selected device) plus the list of text-to-mel models returned by
    ``load_models()``. Models are moved onto the device one at a time while
    they are being used and moved back to the CPU afterwards.
    """

    def __init__(self, out_dir,
                 use_cuda_if_available=True,
                 sample_rate=22_050):
        """Create the output folder and load the vocoder, denoiser and models.

        Args:
            out_dir: Folder that is created if it does not exist yet and
                stored as ``self.out_dir``.
            use_cuda_if_available: Use CUDA when ``torch.cuda.is_available()``
                reports a device; otherwise (or when False) run on the CPU.
            sample_rate: Sample rate passed to ``torchaudio.save`` when
                writing the synthesized audio.
        """
        if not os.path.exists(out_dir):
            # exist_ok avoids a crash if the folder appears between the
            # check above and this call.
            os.makedirs(out_dir, exist_ok=True)
            print(f"Created folder: {out_dir}")

        device = torch.device(
            'cuda' if torch.cuda.is_available() and use_cuda_if_available
            else 'cpu')

        self.vocoder = load_hifigan(config.vocoder_state_path,
                                    config.vocoder_config_path)
        self.denoiser = Denoiser(self.vocoder, mode='zeros')

        self.vocoder.to(device)
        self.denoiser.to(device)

        self.sample_rate = sample_rate
        self.models = load_models()
        self.out_dir = out_dir
        self.device = device

    def tts(self, text_buckw, speed=1, denoise=0.01):
        """Synthesize ``text_buckw`` with every loaded model and save the audio.

        For model number ``i`` the peak-normalized waveform is written to
        ``./app/static/wave{i}.wav``.

        Args:
            text_buckw: Input text handed to each model's ``ttmel`` method
                (presumably Buckwalter-transliterated -- confirm with the
                model implementations).
            speed: Forwarded to ``ttmel`` as ``speed``.
            denoise: Second argument passed to the denoiser.

        Returns:
            One dict per model with the keys ``'name'`` (model name),
            ``'phon'`` (always an empty string) and ``'id'`` (the index used
            in the output file name).
        """
        response_data = []

        for i, (model_name, model) in enumerate(self.models):
            try:
                model.to(self.device)

                # Pure inference: no autograd graph is needed for synthesis.
                with torch.no_grad():
                    mel_spec = model.ttmel(text_buckw, speed=speed)
                    wave = self.vocoder(mel_spec)
                    wave_den = self.denoiser(wave, denoise)

                    # Scale to a peak of 0.99. A completely silent waveform
                    # is left untouched; dividing by its zero peak would
                    # fill the output with NaNs.
                    peak = wave_den.abs().max()
                    if peak > 0:
                        wave_den /= peak
                        wave_den *= 0.99

                # NOTE(review): the target folder is hard-coded here and
                # self.out_dir is not used -- confirm whether the files
                # should be written to out_dir instead.
                torchaudio.save(f'./app/static/wave{i}.wav',
                                wave_den.cpu(), self.sample_rate)

                response_data.append({
                    'name': model_name,
                    'phon': '',
                    'id': i,
                })
            finally:
                # Always release the device, even when synthesis or saving
                # fails, so a failed request does not leave the model there.
                model.cpu()

        return response_data