ColabWan / postprocessing /seedvc /inference_realtime.py
1ripon1's picture
Upload folder using huggingface_hub
7344bef verified
Raw
History Blame Contribute Delete
13.9 kB
import os
import sys
from dotenv import load_dotenv
os.environ['HF_HUB_CACHE'] = './checkpoints/hf_cache'
import shutil
import multiprocessing
import warnings
import yaml
warnings.simplefilter('ignore')
from tqdm import tqdm
from .modules.commons import *
import librosa
import torchaudio
import torchaudio.compliance.kaldi as kaldi
from .hf_utils import load_custom_model_from_hf
import os
import sys
import torch
from .modules.commons import str2bool
# Load model and configuration
# Pick the compute backend once at import time: CUDA, then Apple MPS, then CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

# Streaming state shared with custom_infer / load_models via `global`.
flag_vc = False
prompt_condition = mel2 = style2 = None  # cached reference prompt tensors
reference_wav_name = ""                  # name of the reference the cache was built from
prompt_len = 3                           # max reference prompt length, in seconds
ce_dit_difference = 2.0                  # seconds of semantic context trimmed per block
fp16 = False                             # set by load_models from args.fp16
def _time_call(fn, *args):
    """Call ``fn(*args)`` and return ``(result, elapsed_ms)``.

    On CUDA and MPS the device is synchronized around recorded timing events
    so queued kernels are included in the measurement. On any other device
    (e.g. the CPU fallback chosen at import time) a wall-clock timer is used,
    because CUDA events and ``torch.cuda.synchronize`` need a CUDA device.
    """
    if device.type == "mps":
        start_event = torch.mps.event.Event(enable_timing=True)
        end_event = torch.mps.event.Event(enable_timing=True)
        synchronize = torch.mps.synchronize
    elif device.type == "cuda":
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        synchronize = torch.cuda.synchronize
    else:
        import time
        start = time.perf_counter()
        result = fn(*args)
        return result, (time.perf_counter() - start) * 1000.0
    synchronize()
    start_event.record()
    result = fn(*args)
    end_event.record()
    synchronize()  # wait for the events to be recorded before reading them
    return result, start_event.elapsed_time(end_event)


@torch.no_grad()
def custom_infer(model_set,
                 reference_wav,
                 new_reference_wav_name,
                 input_wav_res,
                 block_frame_16k,
                 skip_head,
                 skip_tail,
                 return_length,
                 diffusion_steps,
                 inference_cfg_rate,
                 max_prompt_length,
                 cd_difference=2.0,
                 ):
    """Convert one streaming block of 16 kHz input audio to the reference voice.

    The reference prompt (semantic condition ``prompt_condition``, mel
    ``mel2`` and speaker style ``style2``) is cached in module globals. It is
    rebuilt on the first call, when ``new_reference_wav_name`` differs from the
    cached name, or when ``max_prompt_length`` changes.

    Args:
        model_set: 7-tuple as returned by ``load_models``:
            ``(model, semantic_fn, f0_fn, vocoder_fn, campplus_model, to_mel,
            mel_fn_args)``.
        reference_wav: Reference waveform at ``mel_fn_args["sampling_rate"]``.
            It is passed to ``torch.from_numpy``, so it must be a NumPy array.
        new_reference_wav_name: Identifier used to detect a reference change.
        input_wav_res: 16 kHz input tensor for this block; it is
            ``unsqueeze(0)``-ed before ``semantic_fn`` (assumed 1-D, confirm
            with caller).
        block_frame_16k: Not used by this function; kept for interface
            compatibility.
        skip_head, skip_tail, return_length: Lengths in 1/50 s frames.
        diffusion_steps: Number of sampling steps for ``model.cfm.inference``.
        inference_cfg_rate: Guidance rate forwarded to ``model.cfm.inference``.
        max_prompt_length: The reference is truncated to this many seconds.
        cd_difference: Seconds of leading semantic features dropped from the
            block before DiT conditioning.

    Returns:
        The vocoded waveform segment of ``return_length`` frames that ends
        just before the trailing ``skip_tail`` frames.
    """
    global prompt_condition, mel2, style2
    global reference_wav_name
    global prompt_len
    global ce_dit_difference
    (
        model,
        semantic_fn,
        f0_fn,
        vocoder_fn,
        campplus_model,
        to_mel,
        mel_fn_args,
    ) = model_set
    sr = mel_fn_args["sampling_rate"]
    hop_length = mel_fn_args["hop_size"]

    if ce_dit_difference != cd_difference:
        ce_dit_difference = cd_difference
        print(f"Setting ce_dit_difference to {cd_difference} seconds.")

    # (Re)build the cached reference prompt only when it is stale.
    if (prompt_condition is None
            or reference_wav_name != new_reference_wav_name
            or prompt_len != max_prompt_length):
        prompt_len = max_prompt_length
        print(f"Setting max prompt length to {max_prompt_length} seconds.")
        reference_wav = reference_wav[:int(sr * prompt_len)]
        reference_wav_tensor = torch.from_numpy(reference_wav).to(device)

        ori_waves_16k = torchaudio.functional.resample(reference_wav_tensor, sr, 16000)
        S_ori = semantic_fn(ori_waves_16k.unsqueeze(0))
        feat2 = torchaudio.compliance.kaldi.fbank(
            ori_waves_16k.unsqueeze(0), num_mel_bins=80, dither=0, sample_frequency=16000
        )
        feat2 = feat2 - feat2.mean(dim=0, keepdim=True)
        style2 = campplus_model(feat2.unsqueeze(0))

        mel2 = to_mel(reference_wav_tensor.unsqueeze(0))
        target2_lengths = torch.LongTensor([mel2.size(2)]).to(mel2.device)
        prompt_condition = model.length_regulator(
            S_ori, ylens=target2_lengths, n_quantizers=3, f0=None
        )[0]
        reference_wav_name = new_reference_wav_name

    S_alt, elapsed_time_ms = _time_call(semantic_fn, input_wav_res.unsqueeze(0))
    print(f"Time taken for semantic_fn: {elapsed_time_ms}ms")

    # Semantic features run at 50 frames per second; drop the leading context.
    ce_dit_frame_difference = int(ce_dit_difference * 50)
    S_alt = S_alt[:, ce_dit_frame_difference:]
    target_lengths = torch.LongTensor(
        [(skip_head + return_length + skip_tail - ce_dit_frame_difference) / 50 * sr // hop_length]
    ).to(S_alt.device)
    print(f"target_lengths: {target_lengths}")
    cond = model.length_regulator(
        S_alt, ylens=target_lengths, n_quantizers=3, f0=None
    )[0]
    cat_condition = torch.cat([prompt_condition, cond], dim=1)

    # Autocast to fp16 only when requested; otherwise run in full precision.
    with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=fp16):
        vc_target = model.cfm.inference(
            cat_condition,
            torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
            mel2,
            style2,
            None,
            n_timesteps=diffusion_steps,
            inference_cfg_rate=inference_cfg_rate,
        )
        # Strip the prompt frames so only the generated block remains.
        vc_target = vc_target[:, :, mel2.size(-1):]
        print(f"vc_target.shape: {vc_target.shape}")
        vc_wave = vocoder_fn(vc_target).squeeze()

    output_len = return_length * sr // 50
    tail_len = skip_tail * sr // 50
    # Explicit clamped indices so skip_tail == 0 still returns the final
    # output_len samples (a `-tail_len` slice end would be 0 and select nothing).
    end = max(len(vc_wave) - tail_len, 0)
    start = max(end - output_len, 0)
    return vc_wave[start:end]
def _read_yaml(path):
    """Return the parsed contents of the YAML file at ``path``, closing the handle."""
    with open(path, "r") as f:
        return yaml.safe_load(f)


def _load_vocoder(model_params):
    """Instantiate the vocoder selected by ``model_params.vocoder.type`` on ``device``.

    Supported types are ``bigvgan``, ``hifigan`` and ``vocos``. The returned
    object is what ``load_models`` exposes as ``vocoder_fn``.

    Raises:
        ValueError: if the vocoder type is not one of the supported ones.
    """
    vocoder_type = model_params.vocoder.type
    if vocoder_type == 'bigvgan':
        from .modules.bigvgan import bigvgan
        bigvgan_name = model_params.vocoder.name
        bigvgan_model = bigvgan.BigVGAN.from_pretrained(bigvgan_name, use_cuda_kernel=False)
        # remove weight norm in the model and set to eval mode
        bigvgan_model.remove_weight_norm()
        return bigvgan_model.eval().to(device)

    if vocoder_type == 'hifigan':
        from .modules.hifigan.generator import HiFTGenerator
        from .modules.hifigan.f0_predictor import ConvRNNF0Predictor
        from ._paths import resolve_path
        hift_config = _read_yaml(resolve_path('configs/hifigan.yml'))
        hift_gen = HiFTGenerator(**hift_config['hift'],
                                 f0_predictor=ConvRNNF0Predictor(**hift_config['f0_predictor']))
        hift_path = load_custom_model_from_hf("FunAudioLLM/CosyVoice-300M", 'hift.pt', None)
        hift_gen.load_state_dict(torch.load(hift_path, map_location='cpu'))
        hift_gen.eval()
        hift_gen.to(device)
        return hift_gen

    if vocoder_type == "vocos":
        vocos_config = _read_yaml(model_params.vocoder.vocos.config)
        vocos_model_params = recursive_munch(vocos_config['model_params'])
        vocos = build_model(vocos_model_params, stage='mel_vocos')
        vocos, _, _, _ = load_checkpoint(vocos, None, model_params.vocoder.vocos.path,
                                         load_only_params=True, ignore_modules=[], is_distributed=False)
        for key in vocos:
            vocos[key].eval().to(device)
        total_params = sum(
            p.numel()
            for key in vocos.keys()
            for p in vocos[key].parameters()
            if p.requires_grad
        )
        print(f"Vocoder model total parameters: {total_params / 1_000_000:.2f}M")
        return vocos.decoder

    raise ValueError(f"Unknown vocoder type: {vocoder_type}")


def _load_semantic_fn(model_params, config):
    """Load the configured speech tokenizer and return its ``semantic_fn``.

    ``semantic_fn(waves_16k)`` takes 16 kHz waveforms indexed by batch
    (the whisper variant squeezes dim 0, so it assumes batch size 1) and
    returns float32 hidden-state features computed on ``device``.

    Raises:
        ValueError: if the speech tokenizer type is not supported.
    """
    speech_tokenizer_type = model_params.speech_tokenizer.type
    if speech_tokenizer_type == 'whisper':
        from transformers import AutoFeatureExtractor, WhisperModel
        whisper_name = model_params.speech_tokenizer.name
        whisper_model = WhisperModel.from_pretrained(whisper_name, torch_dtype=torch.float16).to(device)
        del whisper_model.decoder  # only the encoder is needed for features
        whisper_feature_extractor = AutoFeatureExtractor.from_pretrained(whisper_name)

        def semantic_fn(waves_16k):
            ori_inputs = whisper_feature_extractor([waves_16k.squeeze(0).cpu().numpy()],
                                                   return_tensors="pt",
                                                   return_attention_mask=True)
            ori_input_features = whisper_model._mask_input_features(
                ori_inputs.input_features, attention_mask=ori_inputs.attention_mask).to(device)
            with torch.no_grad():
                ori_outputs = whisper_model.encoder(
                    ori_input_features.to(whisper_model.encoder.dtype),
                    head_mask=None,
                    output_attentions=False,
                    output_hidden_states=False,
                    return_dict=True,
                )
            S_ori = ori_outputs.last_hidden_state.to(torch.float32)
            # Keep one feature frame per 320 input samples (+1).
            S_ori = S_ori[:, :waves_16k.size(-1) // 320 + 1]
            return S_ori

        return semantic_fn

    if speech_tokenizer_type == 'cnhubert':
        from transformers import (
            Wav2Vec2FeatureExtractor,
            HubertModel,
        )
        hubert_model_name = config['model_params']['speech_tokenizer']['name']
        hubert_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(hubert_model_name)
        hubert_model = HubertModel.from_pretrained(hubert_model_name)
        hubert_model = hubert_model.to(device)
        hubert_model = hubert_model.eval()
        hubert_model = hubert_model.half()

        def semantic_fn(waves_16k):
            ori_waves_16k_input_list = [
                waves_16k[bib].cpu().numpy()
                for bib in range(len(waves_16k))
            ]
            ori_inputs = hubert_feature_extractor(ori_waves_16k_input_list,
                                                  return_tensors="pt",
                                                  return_attention_mask=True,
                                                  padding=True,
                                                  sampling_rate=16000).to(device)
            with torch.no_grad():
                ori_outputs = hubert_model(
                    ori_inputs.input_values.half(),
                )
            S_ori = ori_outputs.last_hidden_state.float()
            return S_ori

        return semantic_fn

    if speech_tokenizer_type == 'xlsr':
        from transformers import (
            Wav2Vec2FeatureExtractor,
            Wav2Vec2Model,
        )
        model_name = config['model_params']['speech_tokenizer']['name']
        output_layer = config['model_params']['speech_tokenizer']['output_layer']
        wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
        wav2vec_model = Wav2Vec2Model.from_pretrained(model_name)
        # Truncate the encoder so the last hidden state is the configured layer.
        wav2vec_model.encoder.layers = wav2vec_model.encoder.layers[:output_layer]
        wav2vec_model = wav2vec_model.to(device)
        wav2vec_model = wav2vec_model.eval()
        wav2vec_model = wav2vec_model.half()

        def semantic_fn(waves_16k):
            ori_waves_16k_input_list = [
                waves_16k[bib].cpu().numpy()
                for bib in range(len(waves_16k))
            ]
            ori_inputs = wav2vec_feature_extractor(ori_waves_16k_input_list,
                                                   return_tensors="pt",
                                                   return_attention_mask=True,
                                                   padding=True,
                                                   sampling_rate=16000).to(device)
            with torch.no_grad():
                ori_outputs = wav2vec_model(
                    ori_inputs.input_values.half(),
                )
            S_ori = ori_outputs.last_hidden_state.float()
            return S_ori

        return semantic_fn

    raise ValueError(f"Unknown speech tokenizer type: {speech_tokenizer_type}")


def load_models(args):
    """Load the Seed-VC DiT model and all auxiliary networks for streaming inference.

    Side effect: sets the module-level ``fp16`` flag from ``args.fp16``.

    Args:
        args: Namespace providing ``fp16``, ``checkpoint`` and ``config_path``.
            When ``checkpoint`` is None or empty, the default DiT checkpoint and
            config are fetched from the ``Plachta/Seed-VC`` HuggingFace repo.

    Returns:
        The 7-tuple ``(model, semantic_fn, f0_fn, vocoder_fn, campplus_model,
        to_mel, mel_fn_args)`` consumed by ``custom_infer``. ``f0_fn`` is
        always None here.

    Raises:
        ValueError: for an unsupported vocoder or speech tokenizer type.
    """
    global fp16
    fp16 = args.fp16
    print(f"Using fp16: {fp16}")
    f0_fn = None

    if args.checkpoint is None or args.checkpoint == "":
        dit_checkpoint_path, dit_config_path = load_custom_model_from_hf("Plachta/Seed-VC",
                                                                         "DiT_uvit_tat_xlsr_ema.pth",
                                                                         "config_dit_mel_seed_uvit_xlsr_tiny.yml")
    else:
        dit_checkpoint_path = args.checkpoint
        dit_config_path = args.config_path

    config = _read_yaml(dit_config_path)
    model_params = recursive_munch(config["model_params"])
    model_params.dit_type = 'DiT'
    model = build_model(model_params, stage="DiT")
    spect_params = config["preprocess_params"]["spect_params"]
    sr = config["preprocess_params"]["sr"]

    # Load checkpoints
    model, _, _, _ = load_checkpoint(
        model,
        None,
        dit_checkpoint_path,
        load_only_params=True,
        ignore_modules=[],
        is_distributed=False,
    )
    for key in model:
        model[key].eval()
        model[key].to(device)
    model.cfm.estimator.setup_caches(max_batch_size=1, max_seq_length=8192)

    # Load additional modules: speaker-style encoder (CAMPPlus).
    from .modules.campplus.DTDNN import CAMPPlus
    campplus_ckpt_path = load_custom_model_from_hf(
        "funasr/campplus", "campplus_cn_common.bin", config_filename=None
    )
    campplus_model = CAMPPlus(feat_dim=80, embedding_size=192)
    campplus_model.load_state_dict(torch.load(campplus_ckpt_path, map_location="cpu"))
    campplus_model.eval()
    campplus_model.to(device)

    vocoder_fn = _load_vocoder(model_params)
    semantic_fn = _load_semantic_fn(model_params, config)

    # Generate mel spectrograms
    mel_fn_args = {
        "n_fft": spect_params['n_fft'],
        "win_size": spect_params['win_length'],
        "hop_size": spect_params['hop_length'],
        "num_mels": spect_params['n_mels'],
        "sampling_rate": sr,
        "fmin": spect_params.get('fmin', 0),
        # NOTE(review): any configured fmax other than the string "None" maps
        # to 8000 regardless of its value — confirm this matches training.
        "fmax": None if spect_params.get('fmax', "None") == "None" else 8000,
        "center": False
    }
    from .modules.audio import mel_spectrogram
    to_mel = lambda x: mel_spectrogram(x, **mel_fn_args)

    return (
        model,
        semantic_fn,
        f0_fn,
        vocoder_fn,
        campplus_model,
        to_mel,
        mel_fn_args,
    )