Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| import audiosr.hifigan as hifigan | |
def get_vocoder_config():
    """Return the built-in HiFi-GAN settings for 64 mel bins at 16 kHz.

    A fresh dictionary (with fresh nested lists/dicts) is created on every
    call, so callers may mutate the result freely. ``get_vocoder`` uses this
    configuration when ``mel_bins == 64``.
    """
    # Optimiser / training-loop settings.
    optimisation = {
        "num_gpus": 6,
        "batch_size": 16,
        "learning_rate": 0.0002,
        "adam_b1": 0.8,
        "adam_b2": 0.99,
        "lr_decay": 0.999,
        "seed": 1234,
    }
    # Generator architecture: upsampling stack and residual blocks.
    generator = {
        "upsample_rates": [5, 4, 2, 2, 2],
        "upsample_kernel_sizes": [16, 16, 8, 4, 4],
        "upsample_initial_channel": 1024,
        "resblock_kernel_sizes": [3, 7, 11],
        "resblock_dilation_sizes": [[1, 3, 5] for _ in range(3)],
    }
    # Spectrogram / audio front-end settings.
    audio = {
        "num_mels": 64,
        "num_freq": 1025,
        "n_fft": 1024,
        "hop_size": 160,
        "win_size": 1024,
        "sampling_rate": 16000,
        "fmin": 0,
        "fmax": 8000,
        "fmax_for_loss": None,
    }
    distributed = {
        "dist_backend": "nccl",
        "dist_url": "tcp://localhost:54321",
        "world_size": 1,
    }
    # Merge in the original key order so iteration order is unchanged.
    return {
        "resblock": "1",
        **optimisation,
        **generator,
        "segment_size": 8192,
        **audio,
        "num_workers": 4,
        "dist_config": distributed,
    }
def get_vocoder_config_48k():
    """Return the built-in HiFi-GAN settings for 256 mel bins at 48 kHz.

    A fresh dictionary (with fresh nested lists/dicts) is created on every
    call. ``get_vocoder`` uses this configuration whenever ``mel_bins`` is
    not 64.
    """
    # Optimiser / training-loop settings.
    optimisation = {
        "num_gpus": 8,
        "batch_size": 128,
        "learning_rate": 0.0001,
        "adam_b1": 0.8,
        "adam_b2": 0.99,
        "lr_decay": 0.999,
        "seed": 1234,
    }
    # Generator architecture: upsampling stack and residual blocks.
    generator = {
        "upsample_rates": [6, 5, 4, 2, 2],
        "upsample_kernel_sizes": [12, 10, 8, 4, 4],
        "upsample_initial_channel": 1536,
        "resblock_kernel_sizes": [3, 7, 11, 15],
        "resblock_dilation_sizes": [[1, 3, 5] for _ in range(4)],
    }
    # Spectrogram / audio front-end settings.
    audio = {
        "num_mels": 256,
        "n_fft": 2048,
        "hop_size": 480,
        "win_size": 2048,
        "sampling_rate": 48000,
        "fmin": 20,
        "fmax": 24000,
        "fmax_for_loss": None,
    }
    distributed = {
        "dist_backend": "nccl",
        "dist_url": "tcp://localhost:18273",
        "world_size": 1,
    }
    # Merge in the original key order so iteration order is unchanged.
    return {
        "resblock": "1",
        **optimisation,
        **generator,
        "segment_size": 15360,
        **audio,
        "num_workers": 8,
        "dist_config": distributed,
    }
def get_available_checkpoint_keys(model, ckpt):
    """Return the checkpoint entries that can be loaded into ``model``.

    The checkpoint's ``"state_dict"`` mapping is filtered down to the keys
    that also exist in ``model.state_dict()`` with an identical tensor size.
    Every rejected key is reported on stdout, followed by a summary line.

    Args:
        model: Object exposing ``state_dict()``.
        ckpt: Path or file-like object accepted by ``torch.load``; the loaded
            object must contain a ``"state_dict"`` entry.

    Returns:
        A new dict holding only the matching ``key -> tensor`` pairs.
    """
    # map_location="cpu" lets a checkpoint written from a GPU process be
    # opened on a host where that device is unavailable; the tensors are
    # copied to the model's own device when they are loaded into it.
    # NOTE(security): torch.load unpickles its input -- only pass checkpoints
    # from a trusted source.
    state_dict = torch.load(ckpt, map_location="cpu")["state_dict"]
    current_state_dict = model.state_dict()
    new_state_dict = {}
    for k, tensor in state_dict.items():
        if (
            k in current_state_dict
            and current_state_dict[k].size() == tensor.size()
        ):
            new_state_dict[k] = tensor
        else:
            # Missing from the model, or present with a different shape.
            print("==> WARNING: Skipping %s" % k)
    print(
        "%s out of %s keys are matched"
        % (len(new_state_dict), len(state_dict))
    )
    return new_state_dict
def get_param_num(model):
    """Count the scalar elements across everything in ``model.parameters()``."""
    total = 0
    for tensor in model.parameters():
        total += tensor.numel()
    return total
def torch_version_orig_mod_remove(state_dict):
    """Return a copy of the ``"generator"`` weights with ``_orig_mod.`` removed.

    Every occurrence of the substring ``"_orig_mod."`` is deleted from each
    key of ``state_dict["generator"]`` (presumably the marker that compiled
    module wrappers add to parameter names -- confirm against the checkpoint
    producer). Values are passed through untouched.

    Only the ``"generator"`` entry is carried over; any other top-level keys
    of ``state_dict`` are absent from the result.
    """
    generator_weights = state_dict["generator"]
    # str.replace is a no-op when the marker is absent, so one expression
    # covers both the renamed and the unchanged keys.
    cleaned = {
        name.replace("_orig_mod.", ""): value
        for name, value in generator_weights.items()
    }
    return {"generator": cleaned}
def get_vocoder(config, device, mel_bins):
    """Build a HiFi-GAN generator, put it in inference mode and move it to ``device``.

    Args:
        config: Kept for interface compatibility only. The value is ignored;
            one of this module's built-in configurations is always used.
        device: Target handed to ``vocoder.to``.
        mel_bins: ``64`` selects ``get_vocoder_config()``; any other value
            selects ``get_vocoder_config_48k()``.

    Returns:
        The ``hifigan.Generator_old`` instance after ``eval()``,
        ``remove_weight_norm()`` and ``to(device)`` have been applied.
    """
    # NOTE: no checkpoint is loaded here. The earlier ``torch.load`` /
    # ``load_state_dict`` calls (for "hifigan/g_01080000" and
    # "hifigan/g_00660000") were disabled, so the generator carries whatever
    # weights ``Generator_old`` initialises with until the caller loads a
    # state dict.
    if mel_bins == 64:
        settings = get_vocoder_config()
    else:
        settings = get_vocoder_config_48k()

    vocoder = hifigan.Generator_old(hifigan.AttrDict(settings))
    vocoder.eval()
    vocoder.remove_weight_norm()
    vocoder.to(device)
    return vocoder
def vocoder_infer(mels, vocoder, lengths=None):
    """Run ``vocoder`` on ``mels`` and return 16-bit integer waveforms.

    Args:
        mels: Input passed straight to ``vocoder`` (assumed to be a batch of
            mel spectrograms -- confirm the layout against the caller).
        vocoder: Callable returning a tensor whose dimension 1 is squeezed
            away when it has size 1.
        lengths: Optional single integer slice bound. When given, every
            waveform is truncated to its first ``lengths`` samples along
            axis 1.

    Returns:
        A NumPy ``int16`` array of the scaled waveforms.
    """
    with torch.no_grad():
        wavs = vocoder(mels).squeeze(1)
        # Scale float samples to the 16-bit range. Clip before casting: a
        # sample of exactly 1.0 scales to 32768, which does not fit in int16
        # and would otherwise wrap around to -32768.
        wavs = (
            (wavs.cpu().numpy() * 32768).clip(-32768, 32767).astype("int16")
        )

    if lengths is not None:
        wavs = wavs[:, :lengths]

    return wavs