|
|
|
|
|
|
import time

import librosa
import numpy as np
import phonemizer
import torch
import torchaudio
import yaml
|
|
# Cap intra-op CPU parallelism so inference does not oversubscribe cores.
torch.set_num_threads(8)

# assumes a CUDA-capable GPU is available -- TODO confirm before deploying
device = "cuda"

# Grapheme-to-phoneme backend (espeak-ng, US English). Punctuation and stress
# marks are preserved because the downstream text cleaner consumes them.
global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True, language_switch='remove-flags')
|
|
| |
| |
# Load the fine-tuning config that points at all pretrained sub-models.
# Use a context manager so the file handle is closed promptly; the original
# `yaml.safe_load(open(...))` leaked the open handle.
with open("/workspace/trainTTS/StyleTTS2_custom/Configs/config_ft.yml") as f:
    config = yaml.safe_load(f)

from models import *
from utils import *
from text_utils import TextCleaner

# NOTE(review): "textclenaer" is a typo for "textcleaner"; kept unchanged
# because other code in this file references it by this name.
textclenaer = TextCleaner()
text_aligner = load_ASR_models(config['ASR_path'], config['ASR_config'])  # alignment (ASR) model
pitch_extractor = load_F0_models(config['F0_path'])  # F0 (pitch) extractor
from Utils.PLBERT.util import load_plbert
plbert = load_plbert(config['PLBERT_dir'])  # phoneme-level BERT encoder
|
|
# Assemble the StyleTTS2 model dict from the config plus the pretrained
# components loaded above.
model = build_model(recursive_munch(config['model_params']), text_aligner, pitch_extractor, plbert)

# Load the fine-tuned checkpoint weights. Use the module-level `device` for
# map_location instead of a hard-coded 'cuda' so the two stay consistent.
params = torch.load("/workspace/trainTTS/StyleTTS2_custom/Models/mix5voice/ver2/best_model.pth", map_location=device)['net']
|
|
| |
# Copy each sub-module's weights from the checkpoint into the model,
# dropping the "module." prefix left behind by DataParallel training,
# then switch every sub-module to eval mode on the target device.
for name in model:
    checkpoint_state = params[name]
    stripped_state = {
        (k[len("module."):] if k.startswith("module.") else k): v
        for k, v in checkpoint_state.items()
    }
    model[name].load_state_dict(stripped_state, strict=True)
    model[name].eval().to(device)
|
|
from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
# Diffusion sampler used to draw a style vector conditioned on the BERT text
# embedding: Karras noise schedule with an ADPM2 integrator, no output clamp.
sampler = DiffusionSampler(model.diffusion.diffusion, sampler=ADPM2Sampler(),
    sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), clamp=False)
|
|
| |
def clone_voice(text, reference_audio_path):
    """Synthesize `text` in the voice of the speaker in `reference_audio_path`.

    Returns a 1-D numpy waveform at 24 kHz with the last 50 samples trimmed
    (end-of-utterance decoder artifact).
    """
    # Load the reference audio at the model's 24 kHz rate and trim silence.
    wave, sr = librosa.load(reference_audio_path, sr=24000)
    audio, _ = librosa.effects.trim(wave, top_db=30)

    # Log-mel spectrogram, normalized with the (mean=-4, std=4) constants
    # that presumably match training-time preprocessing -- TODO confirm.
    to_mel = torchaudio.transforms.MelSpectrogram(n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
    mel = to_mel(torch.from_numpy(audio).float())
    mel = (torch.log(1e-5 + mel.unsqueeze(0)) - (-4)) / 4
    mel = mel.to(device)

    # Reference style: acoustic style (ref_s) and prosody style (ref_p),
    # concatenated into a single style vector.
    with torch.no_grad():
        ref_s = model.style_encoder(mel.unsqueeze(1))
        ref_p = model.predictor_encoder(mel.unsqueeze(1))
        ref_style = torch.cat([ref_s, ref_p], dim=1)

    # Phonemize the input. The "vhv.vn" string is a temporary placeholder so
    # that dental t (t̪) survives the aspiration substitution (t -> tʰ).
    ps = global_phonemizer.phonemize([text.strip()])[0]
    ps = ps.replace("t̪", "vhv.vn").replace("t", "tʰ").replace("vhv.vn", "t")

    # Token ids, prepended with a 0 token (presumably padding/BOS -- verify
    # against TextCleaner's symbol table).
    tokens = torch.LongTensor(textclenaer(ps)).to(device).unsqueeze(0)
    tokens = torch.cat([torch.LongTensor([0]).to(device).unsqueeze(0), tokens], dim=-1)

    # Boolean mask that is True on positions past each sequence's length.
    input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
    text_mask = torch.arange(input_lengths.max()).unsqueeze(0).expand(input_lengths.shape[0], -1).type_as(input_lengths)
    text_mask = torch.gt(text_mask+1, input_lengths.unsqueeze(1)).to(device)

    with torch.no_grad():
        # Encode the phoneme sequence with the text encoder and PL-BERT.
        t_en = model.text_encoder(tokens, input_lengths, text_mask)
        bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
        d_en = model.bert_encoder(bert_dur).transpose(-1, -2)

        # Sample a (1, 256) style vector from the diffusion model, conditioned
        # on the BERT embedding and the reference style.
        s_pred = sampler(noise=torch.randn((1, 256)).unsqueeze(1).to(device),
            embedding=bert_dur, features=ref_style, num_steps=5).squeeze(1)

        # Blend sampled and reference styles: the second half (dims 128:)
        # leans on the prediction, the first half (:128) on the reference.
        s = 0.7 * s_pred[:, 128:] + 0.3 * ref_style[:, 128:]
        ref = 0.3 * s_pred[:, :128] + 0.7 * ref_style[:, :128]

        # Predict per-phoneme durations in frames, clamped to at least 1.
        d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
        x, _ = model.predictor.lstm(d)
        duration = torch.sigmoid(model.predictor.duration_proj(x)).sum(axis=-1)
        pred_dur = torch.round(duration.squeeze()).clamp(min=1)

        # Hard alignment matrix (phonemes x total frames): each phoneme row is
        # set to 1 over its predicted span of frames.
        pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
        c_frame = 0
        for i in range(pred_aln_trg.size(0)):
            pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
            c_frame += int(pred_dur[i].data)

        # Expand encodings to frame rate, predict F0/energy, and decode audio.
        en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
        F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
        asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))
        out = model.decoder(asr, F0_pred, N_pred, ref.squeeze().unsqueeze(0))

    # Drop the last 50 samples to remove the trailing decoder artifact.
    return out.squeeze().cpu().numpy()[..., :-50]
|
|
|
|
# Demo utterance(s); extend this list to synthesize more texts.
list_texts = ["so now I want to tell you something about myself, do you know about my hometown, this is a beautiful place"]

for text in list_texts:
    st = time.time()

    # Synthesize sentence by sentence and concatenate the waveforms.
    sens = text.split('.')
    wavs = []
    for sen in sens:
        # Skip empty AND whitespace-only fragments (e.g. after a trailing
        # '.'); the original `if sen:` let whitespace-only strings through
        # to the synthesizer.
        if sen.strip():
            wav = clone_voice(sen, "/workspace/trainTTS/StyleTTS2_custom/sangnq_original.wav")
            wavs.append(wav)

    # `np` comes from the top-of-file imports (numpy was previously missing).
    all_wav = np.concatenate(wavs)
    print('Time cloning voice = ', time.time()-st)
|
|
| |
import soundfile as sf
# Save the synthesized waveform at 24 kHz.
# NOTE(review): `all_wav` is the last value from the loop above, so only the
# final text's audio is written -- fine for a single-item list, but move this
# write inside the loop if list_texts grows.
sf.write('/workspace/trainTTS/StyleTTS2_custom/test_voice_clone/sangnq_en_us_best.wav', all_wav, 24000)