# styletts2-ver2 / inference_vi.py
# NOTE(review): the lines below are Hugging Face upload-page residue, kept as a
# comment so the file is valid Python:
#   hieuducle's picture
#   Upload full StyleTTS2_custom folder
#   1b242be verified
import torch
import torchaudio
import librosa
import yaml
# from nltk.tokenize import word_tokenize
import phonemizer
import time
torch.set_num_threads(8)
# Setup
device = "cuda"#'cuda' if torch.cuda.is_available() else 'cpu'
global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True, language_switch='remove-flags')
# Load model (1 lần duy nhất)
# config = yaml.safe_load(open("/home/general/TTS/train_model/StyleTTS2/Configs/config_ft_hieuld.yml"))
config = yaml.safe_load(open("/workspace/trainTTS/StyleTTS2_custom/Configs/config_ft.yml"))
from models import *
from utils import *
from text_utils import TextCleaner
textclenaer = TextCleaner()
text_aligner = load_ASR_models(config['ASR_path'], config['ASR_config'])
pitch_extractor = load_F0_models(config['F0_path'])
from Utils.PLBERT.util import load_plbert
plbert = load_plbert(config['PLBERT_dir'])
model = build_model(recursive_munch(config['model_params']), text_aligner, pitch_extractor, plbert)
# params = torch.load("/home/general/TTS/train_model/StyleTTS2/merged_model_last.pth", map_location='cuda')['net']
# params = torch.load("/home/general/TTS/train_model/StyleTTS2/Demo/merged_model_now.pth", map_location='cuda')['net']
params = torch.load("/workspace/trainTTS/StyleTTS2_custom/Models/mix5voice/ver2/best_model.pth", map_location='cuda')['net']
# params = torch.load("/home/general/TTS/train_model/StyleTTS2/merged_model_last.pth", map_location='cuda')['net']
for key in model:
state_dict = params[key]
# remove "module." prefix nếu có
new_state_dict = {}
for k, v in state_dict.items():
if k.startswith("module."):
new_state_dict[k[len("module."):]] = v
# elif k.startswith("shared."):
# new_state_dict[k[len("shared."):]] = v
else:
new_state_dict[k] = v
# print(new_state_dict)
# print(len(new_state_dict))
model[key].load_state_dict(new_state_dict, strict=True)
model[key].eval().to(device)
from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
sampler = DiffusionSampler(model.diffusion.diffusion, sampler=ADPM2Sampler(),
sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), clamp=False)
# Hàm clone voice
def clone_voice(text, reference_audio_path):
# Extract style từ reference
wave, sr = librosa.load(reference_audio_path, sr=24000)
audio, _ = librosa.effects.trim(wave, top_db=30)
to_mel = torchaudio.transforms.MelSpectrogram(n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
mel = to_mel(torch.from_numpy(audio).float())
mel = (torch.log(1e-5 + mel.unsqueeze(0)) - (-4)) / 4
mel = mel.to(device)
with torch.no_grad():
ref_s = model.style_encoder(mel.unsqueeze(1))
ref_p = model.predictor_encoder(mel.unsqueeze(1))
ref_style = torch.cat([ref_s, ref_p], dim=1)
# Synthesize
ps = global_phonemizer.phonemize([text.strip()])[0]
ps = ps.replace("t̪", "vhv.vn").replace("t", "tʰ").replace("vhv.vn", "t")
# ps = ' '.join(word_tokenize(ps[0]))
tokens = torch.LongTensor(textclenaer(ps)).to(device).unsqueeze(0)
tokens = torch.cat([torch.LongTensor([0]).to(device).unsqueeze(0), tokens], dim=-1)
input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
text_mask = torch.arange(input_lengths.max()).unsqueeze(0).expand(input_lengths.shape[0], -1).type_as(input_lengths)
text_mask = torch.gt(text_mask+1, input_lengths.unsqueeze(1)).to(device)
with torch.no_grad():
t_en = model.text_encoder(tokens, input_lengths, text_mask)
bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
s_pred = sampler(noise=torch.randn((1, 256)).unsqueeze(1).to(device),
embedding=bert_dur, features=ref_style, num_steps=5).squeeze(1)
s = 0.7 * s_pred[:, 128:] + 0.3 * ref_style[:, 128:]
ref = 0.3 * s_pred[:, :128] + 0.7 * ref_style[:, :128]
d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
x, _ = model.predictor.lstm(d)
duration = torch.sigmoid(model.predictor.duration_proj(x)).sum(axis=-1)
pred_dur = torch.round(duration.squeeze()).clamp(min=1)
pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
c_frame = 0
for i in range(pred_aln_trg.size(0)):
pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
c_frame += int(pred_dur[i].data)
en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))
out = model.decoder(asr, F0_pred, N_pred, ref.squeeze().unsqueeze(0))
return out.squeeze().cpu().numpy()[..., :-50]
list_texts = ["so now I want to tell you something about myself, do you know about my hometown, this is a beautiful place"]
for text in list_texts:
st = time.time()
# wav = clone_voice(text, "/home/general/TTS/train_model/audio/donal_trump.wav")
sens = text.split('.')
wavs = []
for sen in sens:
if sen:
wav = clone_voice(sen, "/workspace/trainTTS/StyleTTS2_custom/sangnq_original.wav")
wavs.append(wav)
all_wav = np.concatenate(wavs)
print('Time cloning voice = ', time.time()-st)
# Save hoặc play
import soundfile as sf
sf.write('/workspace/trainTTS/StyleTTS2_custom/test_voice_clone/sangnq_en_us_best.wav', all_wav, 24000)