Spaces:
Runtime error
Runtime error
| import torch | |
| from pathlib import Path | |
| from .hparams import hparams as hp | |
| from .models.fatchord_version import WaveRNN | |
| from ...log import logger | |
class WaveRNNVocoder:
    """Wrap a WaveRNN network configured from ``hp`` with pretrained weights loaded.

    The network is placed on CUDA when ``torch.cuda.is_available()`` reports a
    GPU, otherwise it stays on the CPU. The selected device is kept in
    ``self._device``, and weights are loaded onto it.
    """

    def __init__(self, model_path: Path):
        """Build the network from hyper-parameters and restore its checkpoint.

        :param model_path: path of a checkpoint readable by ``torch.load``. The
            checkpoint must contain a ``"model_state"`` entry holding the state dict.
        """
        logger.debug("Building Wave-RNN")
        network = WaveRNN(
            rnn_dims=hp.voc_rnn_dims,
            fc_dims=hp.voc_fc_dims,
            bits=hp.bits,
            pad=hp.voc_pad,
            upsample_factors=hp.voc_upsample_factors,
            feat_dims=hp.num_mels,
            compute_dims=hp.voc_compute_dims,
            res_out_dims=hp.voc_res_out_dims,
            res_blocks=hp.voc_res_blocks,
            hop_length=hp.hop_length,
            sample_rate=hp.sample_rate,
            mode=hp.voc_mode,
        )

        use_cuda = torch.cuda.is_available()
        if use_cuda:
            network = network.cuda()
        self._device = torch.device("cuda" if use_cuda else "cpu")
        self._model = network

        logger.debug("Loading model weights at %s" % model_path)
        # map_location puts the stored tensors on the device chosen above.
        state = torch.load(model_path, map_location=self._device)
        self._model.load_state_dict(state["model_state"])
        # Switch to inference mode (disables dropout / batch-norm updates).
        self._model.eval()

    def infer_waveform(
        self, mel, normalize=True, batched=True, target=8000, overlap=800
    ):
        """Turn a synthesizer mel spectrogram into a waveform.

        The mel spectrogram must use the same format the synthesizer produces.

        :param mel: array accepted by ``torch.from_numpy`` (i.e. a NumPy array).
            A leading batch axis of size 1 is added before inference.
        :param normalize: if true, divide ``mel`` by ``hp.mel_max_abs_value``
            first.
        :param batched: forwarded unchanged to ``WaveRNN.generate``.
        :param target: forwarded to ``WaveRNN.generate``. Presumably the number
            of samples per fold in batched mode; confirm against ``generate``.
        :param overlap: forwarded to ``WaveRNN.generate``. Presumably the number
            of samples overlapping between folds; confirm against ``generate``.
        :return: a ``(wav, hp.sample_rate)`` tuple, where ``wav`` is whatever
            ``WaveRNN.generate`` returns.
        """
        scaled = mel / hp.mel_max_abs_value if normalize else mel
        batch = torch.from_numpy(scaled[None, ...])
        waveform = self._model.generate(batch, batched, target, overlap, hp.mu_law)
        return waveform, hp.sample_rate