| import os |
| import sys |
| from dotenv import load_dotenv |
|
|
| os.environ['HF_HUB_CACHE'] = './checkpoints/hf_cache' |
|
|
| import shutil |
|
|
| import multiprocessing |
| import warnings |
| import yaml |
|
|
| warnings.simplefilter('ignore') |
|
|
| from tqdm import tqdm |
| from .modules.commons import * |
| import librosa |
| import torchaudio |
| import torchaudio.compliance.kaldi as kaldi |
|
|
| from .hf_utils import load_custom_model_from_hf |
|
|
| import os |
| import sys |
| import torch |
| from .modules.commons import str2bool |
|
|
| |
| |
# Pick the compute device: CUDA if available, otherwise Apple MPS, otherwise CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

# Module-level flag; not read by the functions in this section.
flag_vc = False

# Reference-prompt features cached across calls; (re)computed lazily by
# custom_infer() when the reference or prompt length changes.
prompt_condition = mel2 = style2 = None
reference_wav_name = ""

# Runtime settings overwritten by custom_infer() / load_models().
prompt_len = 3           # max reference prompt length, in seconds
ce_dit_difference = 2.0  # seconds of leading semantic context dropped per block
fp16 = False             # set from args.fp16 in load_models()
@torch.no_grad()
def custom_infer(model_set,
                 reference_wav,
                 new_reference_wav_name,
                 input_wav_res,
                 block_frame_16k,
                 skip_head,
                 skip_tail,
                 return_length,
                 diffusion_steps,
                 inference_cfg_rate,
                 max_prompt_length,
                 cd_difference=2.0,
                 ):
    """Convert one streaming block of 16 kHz input audio into the reference voice.

    Reference-prompt features (``prompt_condition``, ``mel2``, ``style2``) are
    cached in module globals and rebuilt only when nothing is cached yet, the
    reference name changes, or ``max_prompt_length`` changes.

    Args:
        model_set: Tuple returned by ``load_models``:
            (model, semantic_fn, f0_fn, vocoder_fn, campplus_model, to_mel,
            mel_fn_args).
        reference_wav: Reference audio as a NumPy array at
            ``mel_fn_args["sampling_rate"]``; only the first
            ``max_prompt_length`` seconds are used.
        new_reference_wav_name: Identifier of ``reference_wav``; a change
            invalidates the cached prompt features.
        input_wav_res: 16 kHz input tensor passed to ``semantic_fn``.
        block_frame_16k: Not used by this function; kept for callers.
        skip_head, skip_tail, return_length: Lengths in units of 1/50 s
            (the code converts them with ``/ 50`` and ``// 50``) of the
            leading context, trailing context and returned segment.
        diffusion_steps: Number of CFM inference timesteps.
        inference_cfg_rate: Classifier-free guidance rate for CFM inference.
        max_prompt_length: Maximum reference prompt length, in seconds.
        cd_difference: Seconds of leading semantic frames to drop; stored in
            the ``ce_dit_difference`` global.

    Returns:
        Up to ``return_length * sr // 50`` samples of converted audio, ending
        ``skip_tail * sr // 50`` samples before the end of the vocoder output.
    """
    global prompt_condition, mel2, style2
    global reference_wav_name
    global prompt_len
    global ce_dit_difference
    (
        model,
        semantic_fn,
        f0_fn,
        vocoder_fn,
        campplus_model,
        to_mel,
        mel_fn_args,
    ) = model_set
    sr = mel_fn_args["sampling_rate"]
    hop_length = mel_fn_args["hop_size"]
    if ce_dit_difference != cd_difference:
        ce_dit_difference = cd_difference
        print(f"Setting ce_dit_difference to {cd_difference} seconds.")

    # Rebuild the cached reference-prompt features only when they are stale.
    if prompt_condition is None or reference_wav_name != new_reference_wav_name or prompt_len != max_prompt_length:
        prompt_len = max_prompt_length
        print(f"Setting max prompt length to {max_prompt_length} seconds.")
        reference_wav = reference_wav[:int(sr * prompt_len)]
        reference_wav_tensor = torch.from_numpy(reference_wav).to(device)

        # Semantic features and speaker style are computed from 16 kHz audio.
        ori_waves_16k = torchaudio.functional.resample(reference_wav_tensor, sr, 16000)
        S_ori = semantic_fn(ori_waves_16k.unsqueeze(0))
        feat2 = torchaudio.compliance.kaldi.fbank(
            ori_waves_16k.unsqueeze(0), num_mel_bins=80, dither=0, sample_frequency=16000
        )
        # Subtract the mean over dim 0 (frames, per kaldi.fbank's output layout).
        feat2 = feat2 - feat2.mean(dim=0, keepdim=True)
        style2 = campplus_model(feat2.unsqueeze(0))

        # Mel prompt at the model rate; semantic features are length-regulated
        # to its frame count.
        mel2 = to_mel(reference_wav_tensor.unsqueeze(0))
        target2_lengths = torch.LongTensor([mel2.size(2)]).to(mel2.device)
        prompt_condition = model.length_regulator(
            S_ori, ylens=target2_lengths, n_quantizers=3, f0=None
        )[0]

        reference_wav_name = new_reference_wav_name

    converted_waves_16k = input_wav_res

    # Time the semantic encoder. Accelerator work is asynchronous, so use
    # device events with explicit synchronisation; a CPU-only host has no
    # CUDA runtime (torch.cuda.* would raise), so use a wall-clock timer there.
    if device.type in ("cuda", "mps"):
        if device.type == "mps":
            sync = torch.mps.synchronize
            start_event = torch.mps.event.Event(enable_timing=True)
            end_event = torch.mps.event.Event(enable_timing=True)
        else:
            sync = torch.cuda.synchronize
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
        sync()
        start_event.record()
        S_alt = semantic_fn(converted_waves_16k.unsqueeze(0))
        end_event.record()
        sync()
        elapsed_time_ms = start_event.elapsed_time(end_event)
    else:
        import time
        start_time = time.perf_counter()
        S_alt = semantic_fn(converted_waves_16k.unsqueeze(0))
        elapsed_time_ms = (time.perf_counter() - start_time) * 1000.0
    print(f"Time taken for semantic_fn: {elapsed_time_ms}ms")

    # Drop the leading ce_dit_difference seconds of semantic frames
    # (assumes semantic_fn emits 50 frames per second — TODO confirm).
    ce_dit_frame_difference = int(ce_dit_difference * 50)
    S_alt = S_alt[:, ce_dit_frame_difference:]
    # Mel frames to generate for the remaining span: 1/50 s units -> seconds
    # -> samples -> mel frames.
    target_lengths = torch.LongTensor(
        [int((skip_head + return_length + skip_tail - ce_dit_frame_difference) / 50 * sr // hop_length)]
    ).to(S_alt.device)
    print(f"target_lengths: {target_lengths}")
    cond = model.length_regulator(
        S_alt, ylens=target_lengths, n_quantizers=3, f0=None
    )[0]
    # Prepend the cached prompt condition so the CFM is conditioned on it.
    cat_condition = torch.cat([prompt_condition, cond], dim=1)
    with torch.autocast(device_type=device.type, dtype=torch.float16 if fp16 else torch.float32):
        vc_target = model.cfm.inference(
            cat_condition,
            torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
            mel2,
            style2,
            None,
            n_timesteps=diffusion_steps,
            inference_cfg_rate=inference_cfg_rate,
        )
        # Discard the frames that correspond to the reference prompt.
        vc_target = vc_target[:, :, mel2.size(-1):]
    print(f"vc_target.shape: {vc_target.shape}")
    vc_wave = vocoder_fn(vc_target).squeeze()
    output_len = int(return_length * sr // 50)
    tail_len = int(skip_tail * sr // 50)
    # Use explicit, clamped non-negative bounds: a negative-index slice ending
    # at `-tail_len` becomes `[...:0]` (empty) when tail_len == 0.
    end = max(len(vc_wave) - tail_len, 0)
    start = max(end - output_len, 0)
    output = vc_wave[start:end]

    return output
|
|
def _load_yaml(path):
    """Parse the YAML file at ``path`` with ``yaml.safe_load``, closing the handle."""
    with open(path, "r") as f:
        return yaml.safe_load(f)


def _load_vocoder(model_params):
    """Build the vocoder selected by ``model_params.vocoder.type`` on ``device``.

    Returns the module that ``custom_infer`` calls as ``vocoder_fn``.

    Raises:
        ValueError: if the vocoder type is not bigvgan, hifigan or vocos.
    """
    vocoder_type = model_params.vocoder.type

    if vocoder_type == 'bigvgan':
        from .modules.bigvgan import bigvgan
        bigvgan_name = model_params.vocoder.name
        bigvgan_model = bigvgan.BigVGAN.from_pretrained(bigvgan_name, use_cuda_kernel=False)
        bigvgan_model.remove_weight_norm()
        return bigvgan_model.eval().to(device)

    if vocoder_type == 'hifigan':
        from .modules.hifigan.generator import HiFTGenerator
        from .modules.hifigan.f0_predictor import ConvRNNF0Predictor
        from ._paths import resolve_path
        hift_config = _load_yaml(resolve_path('configs/hifigan.yml'))
        hift_gen = HiFTGenerator(**hift_config['hift'], f0_predictor=ConvRNNF0Predictor(**hift_config['f0_predictor']))
        hift_path = load_custom_model_from_hf("FunAudioLLM/CosyVoice-300M", 'hift.pt', None)
        # NOTE(review): torch.load unpickles the downloaded file; consider
        # weights_only=True if the installed torch version supports it.
        hift_gen.load_state_dict(torch.load(hift_path, map_location='cpu'))
        hift_gen.eval()
        hift_gen.to(device)
        return hift_gen

    if vocoder_type == "vocos":
        vocos_config = _load_yaml(model_params.vocoder.vocos.config)
        vocos_path = model_params.vocoder.vocos.path
        vocos_model_params = recursive_munch(vocos_config['model_params'])
        vocos = build_model(vocos_model_params, stage='mel_vocos')
        vocos, _, _, _ = load_checkpoint(vocos, None, vocos_path,
                                         load_only_params=True, ignore_modules=[], is_distributed=False)
        for key in vocos:
            vocos[key].eval().to(device)
        total_params = sum(
            p.numel()
            for key in vocos
            for p in vocos[key].parameters()
            if p.requires_grad
        )
        print(f"Vocoder model total parameters: {total_params / 1_000_000:.2f}M")
        return vocos.decoder

    raise ValueError(f"Unknown vocoder type: {vocoder_type}")


def _make_wav2vec_style_semantic_fn(feature_extractor, encoder):
    """Return a ``semantic_fn`` that runs a half-precision HF audio encoder.

    Shared by the cnhubert and xlsr tokenizers, which use the same call pattern.
    """
    def semantic_fn(waves_16k):
        # One NumPy array per batch row for the feature extractor.
        ori_waves_16k_input_list = [
            waves_16k[bib].cpu().numpy()
            for bib in range(len(waves_16k))
        ]
        ori_inputs = feature_extractor(ori_waves_16k_input_list,
                                       return_tensors="pt",
                                       return_attention_mask=True,
                                       padding=True,
                                       sampling_rate=16000).to(device)
        with torch.no_grad():
            ori_outputs = encoder(
                ori_inputs.input_values.half(),
            )
        S_ori = ori_outputs.last_hidden_state.float()
        return S_ori

    return semantic_fn


def _build_semantic_fn(model_params, config):
    """Load the configured speech tokenizer and return its ``semantic_fn``.

    ``semantic_fn`` maps 16 kHz waveforms to float32 hidden states.

    Raises:
        ValueError: if the tokenizer type is not whisper, cnhubert or xlsr.
    """
    speech_tokenizer_type = model_params.speech_tokenizer.type

    if speech_tokenizer_type == 'whisper':
        from transformers import AutoFeatureExtractor, WhisperModel
        whisper_name = model_params.speech_tokenizer.name
        whisper_model = WhisperModel.from_pretrained(whisper_name, torch_dtype=torch.float16).to(device)
        # Only the encoder is used below.
        del whisper_model.decoder
        whisper_feature_extractor = AutoFeatureExtractor.from_pretrained(whisper_name)

        def semantic_fn(waves_16k):
            ori_inputs = whisper_feature_extractor([waves_16k.squeeze(0).cpu().numpy()],
                                                   return_tensors="pt",
                                                   return_attention_mask=True)
            ori_input_features = whisper_model._mask_input_features(
                ori_inputs.input_features, attention_mask=ori_inputs.attention_mask).to(device)
            with torch.no_grad():
                ori_outputs = whisper_model.encoder(
                    ori_input_features.to(whisper_model.encoder.dtype),
                    head_mask=None,
                    output_attentions=False,
                    output_hidden_states=False,
                    return_dict=True,
                )
            S_ori = ori_outputs.last_hidden_state.to(torch.float32)
            # Keep waves_16k.size(-1) // 320 + 1 frames; the rest is padding.
            S_ori = S_ori[:, :waves_16k.size(-1) // 320 + 1]
            return S_ori

        return semantic_fn

    if speech_tokenizer_type == 'cnhubert':
        from transformers import (
            Wav2Vec2FeatureExtractor,
            HubertModel,
        )
        hubert_model_name = config['model_params']['speech_tokenizer']['name']
        hubert_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(hubert_model_name)
        hubert_model = HubertModel.from_pretrained(hubert_model_name)
        hubert_model = hubert_model.to(device)
        hubert_model = hubert_model.eval()
        hubert_model = hubert_model.half()
        return _make_wav2vec_style_semantic_fn(hubert_feature_extractor, hubert_model)

    if speech_tokenizer_type == 'xlsr':
        from transformers import (
            Wav2Vec2FeatureExtractor,
            Wav2Vec2Model,
        )
        model_name = config['model_params']['speech_tokenizer']['name']
        output_layer = config['model_params']['speech_tokenizer']['output_layer']
        wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
        wav2vec_model = Wav2Vec2Model.from_pretrained(model_name)
        # Truncate the encoder so its last layer is `output_layer`.
        wav2vec_model.encoder.layers = wav2vec_model.encoder.layers[:output_layer]
        wav2vec_model = wav2vec_model.to(device)
        wav2vec_model = wav2vec_model.eval()
        wav2vec_model = wav2vec_model.half()
        return _make_wav2vec_style_semantic_fn(wav2vec_feature_extractor, wav2vec_model)

    raise ValueError(f"Unknown speech tokenizer type: {speech_tokenizer_type}")


def load_models(args):
    """Load every model needed for real-time voice conversion.

    Sets the module-level ``fp16`` flag from ``args.fp16``. If
    ``args.checkpoint`` is None or empty, the default DiT checkpoint and config
    are fetched with ``load_custom_model_from_hf``; otherwise
    ``args.checkpoint`` and ``args.config_path`` are used.

    Returns:
        ``(model, semantic_fn, f0_fn, vocoder_fn, campplus_model, to_mel,
        mel_fn_args)`` in the order ``custom_infer`` unpacks it; ``f0_fn`` is
        always None here.

    Raises:
        ValueError: unknown vocoder or speech tokenizer type in the config.
    """
    global fp16
    fp16 = args.fp16
    print(f"Using fp16: {fp16}")
    f0_fn = None

    # --- DiT model ---
    if args.checkpoint is None or args.checkpoint == "":
        dit_checkpoint_path, dit_config_path = load_custom_model_from_hf("Plachta/Seed-VC",
                                                                         "DiT_uvit_tat_xlsr_ema.pth",
                                                                         "config_dit_mel_seed_uvit_xlsr_tiny.yml")
    else:
        dit_checkpoint_path = args.checkpoint
        dit_config_path = args.config_path
    config = _load_yaml(dit_config_path)
    model_params = recursive_munch(config["model_params"])
    model_params.dit_type = 'DiT'
    model = build_model(model_params, stage="DiT")
    spect_params = config["preprocess_params"]["spect_params"]
    sr = config["preprocess_params"]["sr"]

    model, _, _, _ = load_checkpoint(
        model,
        None,
        dit_checkpoint_path,
        load_only_params=True,
        ignore_modules=[],
        is_distributed=False,
    )
    for key in model:
        model[key].eval()
        model[key].to(device)
    model.cfm.estimator.setup_caches(max_batch_size=1, max_seq_length=8192)

    # --- speaker style encoder (CAMPPlus) ---
    from .modules.campplus.DTDNN import CAMPPlus

    campplus_ckpt_path = load_custom_model_from_hf(
        "funasr/campplus", "campplus_cn_common.bin", config_filename=None
    )
    campplus_model = CAMPPlus(feat_dim=80, embedding_size=192)
    campplus_model.load_state_dict(torch.load(campplus_ckpt_path, map_location="cpu"))
    campplus_model.eval()
    campplus_model.to(device)

    # --- vocoder and speech tokenizer ---
    vocoder_fn = _load_vocoder(model_params)
    semantic_fn = _build_semantic_fn(model_params, config)

    # --- mel spectrogram front end ---
    mel_fn_args = {
        "n_fft": spect_params['n_fft'],
        "win_size": spect_params['win_length'],
        "hop_size": spect_params['hop_length'],
        "num_mels": spect_params['n_mels'],
        "sampling_rate": sr,
        "fmin": spect_params.get('fmin', 0),
        # NOTE(review): any configured fmax other than "None" becomes 8000,
        # not the configured value; presumably this must match the training
        # front end, so confirm before changing.
        "fmax": None if spect_params.get('fmax', "None") == "None" else 8000,
        "center": False
    }
    from .modules.audio import mel_spectrogram

    def to_mel(x):
        return mel_spectrogram(x, **mel_fn_args)

    return (
        model,
        semantic_fn,
        f0_fn,
        vocoder_fn,
        campplus_model,
        to_mel,
        mel_fn_args,
    )
|
|