# StyleTTS2 voice-cloning inference script.
import time

import librosa
import numpy as np
import phonemizer
import torch
import torchaudio
import yaml

# from nltk.tokenize import word_tokenize
torch.set_num_threads(8)

# Setup: inference device and the grapheme-to-phoneme backend.
device = "cuda"  # 'cuda' if torch.cuda.is_available() else 'cpu'
global_phonemizer = phonemizer.backend.EspeakBackend(
    language='en-us',
    preserve_punctuation=True,
    with_stress=True,
    language_switch='remove-flags',
)
# Load config and build the StyleTTS2 model (done once at import time).
# Use a context manager so the config file handle is closed (the original
# `yaml.safe_load(open(...))` leaked it).
with open("/workspace/trainTTS/StyleTTS2_custom/Configs/config_ft.yml") as f:
    config = yaml.safe_load(f)

from models import *
from utils import *
from text_utils import TextCleaner

textclenaer = TextCleaner()

# Pretrained auxiliary networks required by build_model().
text_aligner = load_ASR_models(config['ASR_path'], config['ASR_config'])
pitch_extractor = load_F0_models(config['F0_path'])

from Utils.PLBERT.util import load_plbert
plbert = load_plbert(config['PLBERT_dir'])

model = build_model(recursive_munch(config['model_params']), text_aligner, pitch_extractor, plbert)

# Fine-tuned checkpoint; map_location follows the inference device instead of
# a second hardcoded 'cuda' string.
params = torch.load(
    "/workspace/trainTTS/StyleTTS2_custom/Models/mix5voice/ver2/best_model.pth",
    map_location=device,
)['net']
# Copy checkpoint weights into each sub-module, dropping any DataParallel
# "module." prefix left over from multi-GPU training.
for key in model:
    raw_state = params[key]
    cleaned = {
        (name[len("module."):] if name.startswith("module.") else name): tensor
        for name, tensor in raw_state.items()
    }
    model[key].load_state_dict(cleaned, strict=True)
    model[key].eval().to(device)
from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule

# Diffusion sampler that draws style vectors conditioned on PL-BERT features.
sampler = DiffusionSampler(
    model.diffusion.diffusion,
    sampler=ADPM2Sampler(),
    sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0),
    clamp=False,
)
# Mel front end with StyleTTS2 training settings — built once at module scope
# instead of on every clone_voice() call (the original rebuilt it per call).
_to_mel = torchaudio.transforms.MelSpectrogram(
    n_mels=80, n_fft=2048, win_length=1200, hop_length=300)


def clone_voice(text, reference_audio_path):
    """Synthesize `text` in the voice of the speaker in `reference_audio_path`.

    Args:
        text: English text to speak (phonemized with espeak).
        reference_audio_path: path to a reference wav; loaded at 24 kHz.

    Returns:
        1-D numpy float waveform at 24 kHz, with the last 50 samples trimmed
        (decoder edge artifact).
    """
    # --- Reference style extraction ---
    wave, sr = librosa.load(reference_audio_path, sr=24000)
    audio, _ = librosa.effects.trim(wave, top_db=30)
    mel = _to_mel(torch.from_numpy(audio).float())
    # Log-mel normalization matching training: (log(mel) - (-4)) / 4.
    mel = (torch.log(1e-5 + mel.unsqueeze(0)) - (-4)) / 4
    mel = mel.to(device)
    with torch.no_grad():
        ref_s = model.style_encoder(mel.unsqueeze(1))      # acoustic style (128-d)
        ref_p = model.predictor_encoder(mel.unsqueeze(1))  # prosodic style (128-d)
        ref_style = torch.cat([ref_s, ref_p], dim=1)

    # --- Phonemize and tokenize ---
    ps = global_phonemizer.phonemize([text.strip()])[0]
    # Turn plain "t" into aspirated "tʰ" while leaving dental "t̪" intact;
    # "vhv.vn" is a temporary placeholder assumed absent from phoneme output.
    ps = ps.replace("t̪", "vhv.vn").replace("t", "tʰ").replace("vhv.vn", "t")
    tokens = torch.LongTensor(textclenaer(ps)).to(device).unsqueeze(0)
    # Prepend token 0 as the leading pad/BOS symbol expected by the encoder.
    tokens = torch.cat([torch.LongTensor([0]).to(device).unsqueeze(0), tokens], dim=-1)
    input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
    text_mask = torch.arange(input_lengths.max()).unsqueeze(0).expand(input_lengths.shape[0], -1).type_as(input_lengths)
    text_mask = torch.gt(text_mask + 1, input_lengths.unsqueeze(1)).to(device)

    # --- Synthesis ---
    with torch.no_grad():
        t_en = model.text_encoder(tokens, input_lengths, text_mask)
        bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
        d_en = model.bert_encoder(bert_dur).transpose(-1, -2)

        # Sample a style vector from the diffusion model, then blend with the
        # reference: 70/30 toward the sample for prosody ([:,128:]) and
        # 30/70 toward the reference for acoustics ([:,:128]).
        s_pred = sampler(noise=torch.randn((1, 256)).unsqueeze(1).to(device),
                         embedding=bert_dur, features=ref_style, num_steps=5).squeeze(1)
        s = 0.7 * s_pred[:, 128:] + 0.3 * ref_style[:, 128:]
        ref = 0.3 * s_pred[:, :128] + 0.7 * ref_style[:, :128]

        # Predict per-token durations and expand into a hard alignment matrix.
        d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
        x, _ = model.predictor.lstm(d)
        duration = torch.sigmoid(model.predictor.duration_proj(x)).sum(axis=-1)
        pred_dur = torch.round(duration.squeeze()).clamp(min=1)
        pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
        c_frame = 0
        for i in range(pred_aln_trg.size(0)):
            pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
            c_frame += int(pred_dur[i].data)

        # Pitch/energy prediction and waveform decoding.
        en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
        F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
        asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))
        out = model.decoder(asr, F0_pred, N_pred, ref.squeeze().unsqueeze(0))

    return out.squeeze().cpu().numpy()[..., :-50]
import soundfile as sf  # hoisted out of the loop (was re-imported per text)

# Demo driver: split each text into sentences, clone them one by one,
# concatenate the waveforms, and save the result.
list_texts = ["so now I want to tell you something about myself, do you know about my hometown, this is a beautiful place"]
for text in list_texts:
    st = time.time()
    wavs = []
    for sen in text.split('.'):
        sen = sen.strip()
        if sen:  # skip empty / whitespace-only segments
            wavs.append(clone_voice(sen, "/workspace/trainTTS/StyleTTS2_custom/sangnq_original.wav"))
    # np comes from the top-of-file imports; the original used it without
    # ever importing numpy (NameError at runtime).
    all_wav = np.concatenate(wavs)
    print('Time cloning voice = ', time.time() - st)
    sf.write('/workspace/trainTTS/StyleTTS2_custom/test_voice_clone/sangnq_en_us_best.wav', all_wav, 24000)