| import importlib
|
# Registry of vocoder classes, keyed by class name (exact and lower-cased).
VOCODERS = {}


def register_vocoder(cls):
    """Class decorator that records ``cls`` in the ``VOCODERS`` registry.

    The class is stored under two keys: its lower-cased ``__name__`` and its
    ``__name__`` as written. An existing entry under either key is overwritten.

    :param cls: class to register
    :return: ``cls`` itself, unchanged, so this can be used as a decorator
    """
    name = cls.__name__
    # Lower-case key first, then the exact-case key (same order as before;
    # matters only for dict insertion order).
    VOCODERS[name.lower()] = cls
    VOCODERS[name] = cls
    return cls
|
|
|
|
|
def get_vocoder_cls(hparams):
    """Resolve the vocoder class selected by ``hparams['vocoder']``.

    The value is first looked up in the ``VOCODERS`` registry. If it is not a
    registered key, it is treated as a dotted import path of the form
    ``package.module.ClassName``: everything before the last dot is imported
    as a module and the final component is fetched from it with ``getattr``.

    :param hparams: mapping-like object providing a ``'vocoder'`` entry
    :return: the vocoder class
    :raises ValueError: if the value is not registered and has no module part
        (no dot, or nothing before the last dot)
    :raises ImportError: if the module part cannot be imported
    :raises AttributeError: if the imported module has no such attribute
    """
    vocoder_name = hparams['vocoder']
    if vocoder_name in VOCODERS:
        return VOCODERS[vocoder_name]

    # Split "pkg.module.ClassName" at the last dot.
    pkg, _, cls_name = vocoder_name.rpartition(".")
    if not pkg:
        # Without a module part, import_module("") would fail with the opaque
        # "Empty module name"; raise the same exception type with a message
        # that names the offending value and the available choices.
        raise ValueError(
            "Unknown vocoder {!r}: it is not a registered vocoder and not a "
            "dotted 'package.module.ClassName' import path. "
            "Registered vocoders: {}".format(
                vocoder_name, ", ".join(str(key) for key in VOCODERS)
            )
        )
    return getattr(importlib.import_module(pkg), cls_name)
|
|
|
|
|
class BaseVocoder:
    """Base interface for vocoders.

    Both methods are stubs that raise ``NotImplementedError``.
    """

    def spec2wav(self, mel):
        """Turn a mel-spectrogram into a waveform.

        :param mel: mel-spectrogram, [T, 80]
        :return: wav: waveform, [T']
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError

    @staticmethod
    def wav2spec(wav_fn):
        """Produce the waveform and mel-spectrogram for ``wav_fn``.

        :param wav_fn: str (presumably the path of a wav file -- confirm
            against implementations)
        :return: wav, mel: [T, 80]
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError
|
|
|