| import torch |
|
|
| import FlashSR.AudioSR.hifigan as hifigan |
|
|
|
|
def get_vocoder_config():
    """Return the HiFi-GAN hyper-parameter preset for the 16 kHz vocoder.

    Matches the training configuration of the 64-mel-bin checkpoint:
    16 kHz sampling rate, 160-sample hop, and a 5*4*2*2*2 = 160x
    upsampling stack.
    """
    return dict(
        resblock="1",
        num_gpus=6,
        batch_size=16,
        learning_rate=0.0002,
        adam_b1=0.8,
        adam_b2=0.99,
        lr_decay=0.999,
        seed=1234,
        upsample_rates=[5, 4, 2, 2, 2],
        upsample_kernel_sizes=[16, 16, 8, 4, 4],
        upsample_initial_channel=1024,
        resblock_kernel_sizes=[3, 7, 11],
        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        segment_size=8192,
        num_mels=64,
        num_freq=1025,
        n_fft=1024,
        hop_size=160,
        win_size=1024,
        sampling_rate=16000,
        fmin=0,
        fmax=8000,
        fmax_for_loss=None,
        num_workers=4,
        dist_config=dict(
            dist_backend="nccl",
            dist_url="tcp://localhost:54321",
            world_size=1,
        ),
    )
|
|
|
|
def get_vocoder_config_48k():
    """Return the HiFi-GAN hyper-parameter preset for the 48 kHz vocoder.

    Matches the training configuration of the 256-mel-bin checkpoint:
    48 kHz sampling rate, 480-sample hop, and a 6*5*4*2*2 = 480x
    upsampling stack.
    """
    return dict(
        resblock="1",
        num_gpus=8,
        batch_size=128,
        learning_rate=0.0001,
        adam_b1=0.8,
        adam_b2=0.99,
        lr_decay=0.999,
        seed=1234,
        upsample_rates=[6, 5, 4, 2, 2],
        upsample_kernel_sizes=[12, 10, 8, 4, 4],
        upsample_initial_channel=1536,
        resblock_kernel_sizes=[3, 7, 11, 15],
        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5], [1, 3, 5]],
        segment_size=15360,
        num_mels=256,
        n_fft=2048,
        hop_size=480,
        win_size=2048,
        sampling_rate=48000,
        fmin=20,
        fmax=24000,
        fmax_for_loss=None,
        num_workers=8,
        dist_config=dict(
            dist_backend="nccl",
            dist_url="tcp://localhost:18273",
            world_size=1,
        ),
    )
|
|
|
|
def get_available_checkpoint_keys(model, ckpt):
    """Filter a checkpoint's state dict down to keys loadable into *model*.

    Loads the file at *ckpt* (expected to contain a ``"state_dict"`` entry),
    keeps only the tensors whose key exists in ``model.state_dict()`` with a
    matching size, and prints a warning for every key that is skipped.

    Args:
        model: the target ``torch.nn.Module`` whose parameter names/shapes
            define what is "available".
        ckpt: path to a checkpoint file readable by ``torch.load``.

    Returns:
        dict: the subset of the checkpoint's ``state_dict`` that can be
        loaded into *model* without shape conflicts.
    """
    state_dict = torch.load(ckpt)["state_dict"]
    current_state_dict = model.state_dict()
    new_state_dict = {}
    for k, v in state_dict.items():
        # Keep only keys the model knows about, with identical tensor shape.
        if k in current_state_dict and current_state_dict[k].size() == v.size():
            new_state_dict[k] = v
        else:
            print(f"==> WARNING: Skipping {k}")
    print(f"{len(new_state_dict)} out of {len(state_dict)} keys are matched")
    return new_state_dict
|
|
|
|
def get_param_num(model):
    """Return the total number of scalar parameters in *model*."""
    total = 0
    for parameter in model.parameters():
        total += parameter.numel()
    return total
|
|
|
|
def torch_version_orig_mod_remove(state_dict):
    """Strip the ``_orig_mod.`` prefix added by ``torch.compile`` from the
    ``"generator"`` entry of a checkpoint state dict.

    ``str.replace`` is a no-op for keys without the prefix, so a single
    comprehension handles both prefixed and plain keys; tensor values are
    carried over untouched.
    """
    cleaned_generator = {
        key.replace("_orig_mod.", ""): value
        for key, value in state_dict["generator"].items()
    }
    return {"generator": cleaned_generator}
|
|
|
|
def get_vocoder(config, device, mel_bins):
    """Build a HiFi-GAN vocoder on *device* for the given mel-bin count.

    64 mel bins selects the 16 kHz preset; any other value selects the
    48 kHz preset. The returned generator is in eval mode with weight
    norm removed.

    Args:
        config: ignored — kept only for interface compatibility. The
            original code unconditionally overwrote it with a built-in
            preset, and so does this version.
        device: target device for the generator.
        mel_bins: number of mel bins of the spectrograms to be vocoded.

    Returns:
        hifigan.Generator_old: the ready-to-use vocoder.

    Note: the original implementation contained an unreachable MelGAN
    branch (the selector was hard-coded to "HiFi-GAN"); it has been
    removed, along with the duplicated setup code in the two HiFi-GAN
    arms. Behavior is unchanged for every possible call.
    """
    if mel_bins == 64:
        preset = get_vocoder_config()
    else:
        preset = get_vocoder_config_48k()
    hparams = hifigan.AttrDict(preset)
    vocoder = hifigan.Generator_old(hparams)
    vocoder.eval()
    # Weight norm is a training-time reparameterization; fold it away
    # for inference.
    vocoder.remove_weight_norm()
    vocoder.to(device)
    return vocoder
|
|
|
|
def vocoder_infer(mels, vocoder, lengths=None):
    """Run *vocoder* on a batch of mel spectrograms and return int16 audio.

    Args:
        mels: batched mel-spectrogram tensor accepted by *vocoder*.
        vocoder: callable mapping mels to waveforms with a singleton
            channel dimension (squeezed out here).
        lengths: optional sample count; when given, every waveform in the
            batch is truncated to this length.

    Returns:
        numpy.ndarray: int16 waveforms of shape (batch, samples).
    """
    with torch.no_grad():
        audio = vocoder(mels).squeeze(1)

    # Scale [-1, 1] float audio to the 16-bit PCM range.
    audio = (audio.cpu().numpy() * 32768).astype("int16")

    return audio if lengths is None else audio[:, :lengths]
|
|