| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import argparse |
|
|
| import torch |
| import torch.nn as nn |
| from conv_stft import STFT |
| from huggingface_hub import hf_hub_download |
| from vocos import Vocos |
|
|
|
|
| opset_version = 17 |
|
|
|
|
| def get_args(): |
| parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) |
| parser.add_argument( |
| "--vocoder", |
| type=str, |
| default="vocos", |
| choices=["vocos", "bigvgan"], |
| help="Vocoder to export", |
| ) |
| parser.add_argument( |
| "--output-path", |
| type=str, |
| default="./vocos_vocoder.onnx", |
| help="Output path", |
| ) |
| return parser.parse_args() |
|
|
|
|
| class ISTFTHead(nn.Module): |
| def __init__(self, n_fft: int, hop_length: int): |
| super().__init__() |
| self.out = None |
| self.stft = STFT(fft_len=n_fft, win_hop=hop_length, win_len=n_fft) |
|
|
| def forward(self, x: torch.Tensor): |
| x = self.out(x).transpose(1, 2) |
| mag, p = x.chunk(2, dim=1) |
| mag = torch.exp(mag) |
| mag = torch.clip(mag, max=1e2) |
| real = mag * torch.cos(p) |
| imag = mag * torch.sin(p) |
| audio = self.stft.inverse(input1=real, input2=imag, input_type="realimag") |
| return audio |
|
|
|
|
| class VocosVocoder(nn.Module): |
| def __init__(self, vocos_vocoder): |
| super(VocosVocoder, self).__init__() |
| self.vocos_vocoder = vocos_vocoder |
| istft_head_out = self.vocos_vocoder.head.out |
| n_fft = self.vocos_vocoder.head.istft.n_fft |
| hop_length = self.vocos_vocoder.head.istft.hop_length |
| istft_head_for_export = ISTFTHead(n_fft, hop_length) |
| istft_head_for_export.out = istft_head_out |
| self.vocos_vocoder.head = istft_head_for_export |
|
|
| def forward(self, mel): |
| waveform = self.vocos_vocoder.decode(mel) |
| return waveform |
|
|
|
|
| def export_VocosVocoder(vocos_vocoder, output_path, verbose): |
| vocos_vocoder = VocosVocoder(vocos_vocoder).cuda() |
| vocos_vocoder.eval() |
|
|
| dummy_batch_size = 8 |
| dummy_input_length = 500 |
|
|
| dummy_mel = torch.randn(dummy_batch_size, 100, dummy_input_length).cuda() |
|
|
| with torch.no_grad(): |
| dummy_waveform = vocos_vocoder(mel=dummy_mel) |
| print(dummy_waveform.shape) |
|
|
| dummy_input = dummy_mel |
|
|
| torch.onnx.export( |
| vocos_vocoder, |
| dummy_input, |
| output_path, |
| opset_version=opset_version, |
| do_constant_folding=True, |
| input_names=["mel"], |
| output_names=["waveform"], |
| dynamic_axes={ |
| "mel": {0: "batch_size", 2: "input_length"}, |
| "waveform": {0: "batch_size", 1: "output_length"}, |
| }, |
| verbose=verbose, |
| ) |
|
|
| print("Exported to {}".format(output_path)) |
|
|
|
|
| def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device="cpu", hf_cache_dir=None): |
| if vocoder_name == "vocos": |
| |
| if is_local: |
| print(f"Load vocos from local path {local_path}") |
| config_path = f"{local_path}/config.yaml" |
| model_path = f"{local_path}/pytorch_model.bin" |
| else: |
| print("Download Vocos from huggingface charactr/vocos-mel-24khz") |
| repo_id = "charactr/vocos-mel-24khz" |
| config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml") |
| model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin") |
| vocoder = Vocos.from_hparams(config_path) |
| state_dict = torch.load(model_path, map_location="cpu", weights_only=True) |
| vocoder.load_state_dict(state_dict) |
| vocoder = vocoder.eval().to(device) |
| elif vocoder_name == "bigvgan": |
| raise NotImplementedError("BigVGAN is not supported yet") |
| vocoder.remove_weight_norm() |
| vocoder = vocoder.eval().to(device) |
| return vocoder |
|
|
|
|
| if __name__ == "__main__": |
| args = get_args() |
| vocoder = load_vocoder(vocoder_name=args.vocoder, device="cpu", hf_cache_dir=None) |
| if args.vocoder == "vocos": |
| export_VocosVocoder(vocoder, args.output_path, verbose=False) |