| |
|
| |
|
| | import torch |
| | import torchaudio |
| | import librosa |
| | import yaml |
| | import numpy as np |
| | import soundfile as sf |
| | import phonemizer |
| | from munch import Munch |
| | import os |
| | import time |
| |
|
| | |
| | from models import * |
| | from utils import * |
| | from text_utils import TextCleaner |
| | from Utils.PLBERT.util import load_plbert |
| | from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule |
| |
|
| | |
# Fine-tuning config, checkpoint, reference voice and output path.
# NOTE(review): absolute workspace paths — adjust per deployment.
CONFIG_PATH = "/workspace/trainTTS/StyleTTS2_custom/Configs/config_ft.yml"
MODEL_PATH = "/workspace/trainTTS/StyleTTS2_custom/Models/mix5voice/model_iter_00032000.pth"
REF_AUDIO_PATH = "/workspace/trainTTS/StyleTTS2_custom/test_voice_clone/hue_ban_mai.wav"
OUTPUT_WAV = "/workspace/trainTTS/StyleTTS2_custom/test_voice_clone/hue_ban_mai_cut.wav"

# Prefer GPU when available; all modules are moved to this device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
| | |
| |
|
class StyleTTS2Inference:
    """Reference-based StyleTTS2 inference (voice cloning).

    Loads a fine-tuned StyleTTS2 checkpoint together with its auxiliary
    models (ASR text aligner, F0 extractor, PL-BERT) and synthesizes
    speech conditioned on a style vector extracted from a reference
    audio clip.
    """

    def __init__(self, config_path, model_path, device=DEVICE):
        """Build all sub-modules and load checkpoint weights.

        Args:
            config_path: YAML config used for fine-tuning (model paths
                and ``model_params``).
            model_path: ``.pth`` checkpoint produced by training.
            device: torch device string; defaults to CUDA when available.
        """
        self.device = device
        # Use a context manager so the config file handle is closed
        # deterministically (original leaked it via yaml.safe_load(open(...))).
        with open(config_path) as f:
            self.config = yaml.safe_load(f)

        # Vietnamese G2P front-end; punctuation and stress are preserved
        # because the text encoder consumes them.
        self.phonemizer = phonemizer.backend.EspeakBackend(
            language='vi', preserve_punctuation=True, with_stress=True
        )
        self.text_cleaner = TextCleaner()

        # Mel extractor is stateless — build once here instead of on every
        # preprocess_audio() call. Parameters must match training.
        self.to_mel = torchaudio.transforms.MelSpectrogram(
            n_mels=80, n_fft=2048, win_length=1200, hop_length=300
        )

        # Frozen auxiliary models required to assemble the main network.
        text_aligner = load_ASR_models(self.config['ASR_path'], self.config['ASR_config'])
        pitch_extractor = load_F0_models(self.config['F0_path'])
        plbert = load_plbert(self.config['PLBERT_dir'])

        model_params = recursive_munch(self.config['model_params'])
        self.model = build_model(model_params, text_aligner, pitch_extractor, plbert)

        print(f"Loading model from: {model_path}")
        params = torch.load(model_path, map_location='cpu')

        # Training checkpoints nest the weights under a 'net' key.
        if 'net' in params:
            params = params['net']

        for key in self.model:
            # Training-only modules (e.g. discriminators) may be absent from
            # an inference checkpoint; skipping them is expected and harmless.
            if key not in params:
                print(f"⚠️ Bỏ qua module '{key}' (không tìm thấy trong checkpoint - OK với model inference)")
                continue

            # Strip the 'module.' prefix left over from DataParallel training.
            state_dict = params[key]
            new_state_dict = {}
            for k, v in state_dict.items():
                if k.startswith("module."):
                    new_state_dict[k[len("module."):]] = v
                else:
                    new_state_dict[k] = v

            self.model[key].load_state_dict(new_state_dict, strict=True)
            self.model[key].eval().to(self.device)
            print(f"✅ Loaded module: {key}")

        # Diffusion sampler used to draw the predicted style vector.
        self.sampler = DiffusionSampler(
            self.model.diffusion.diffusion,
            sampler=ADPM2Sampler(),
            sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0),
            clamp=False
        )
        print("Model initialization complete.\n")

    def preprocess_audio(self, audio_path):
        """Convert a reference audio file into a style vector.

        Returns:
            Tensor of shape [1, 256]: acoustic style from the style
            encoder concatenated with prosodic style from the predictor
            encoder (128 dims each).
        """
        wave, _ = librosa.load(audio_path, sr=24000)
        # Drop leading/trailing silence so it does not dilute the style.
        audio, _ = librosa.effects.trim(wave, top_db=30)

        mel = self.to_mel(torch.from_numpy(audio).float())
        # Log-compress, then normalize with the training-time mean (-4) / std (4).
        mel = (torch.log(1e-5 + mel.unsqueeze(0)) - (-4)) / 4
        mel = mel.to(self.device)

        with torch.no_grad():
            ref_s = self.model.style_encoder(mel.unsqueeze(1))
            ref_p = self.model.predictor_encoder(mel.unsqueeze(1))
            ref_style = torch.cat([ref_s, ref_p], dim=1)

        return ref_style

    def preprocess_text(self, text):
        """Phonemize and tokenize text.

        Returns:
            LongTensor of shape [1, T+1] (token 0 prepended), or None
            when the stripped text is empty.
        """
        text = text.strip()
        if not text:
            return None

        ps = self.phonemizer.phonemize([text])[0]
        tokens = torch.LongTensor(self.text_cleaner(ps)).to(self.device).unsqueeze(0)
        # Prepend token index 0 — the text encoder expects this leading token.
        tokens = torch.cat([torch.LongTensor([0]).to(self.device).unsqueeze(0), tokens], dim=-1)
        return tokens

    def inference(self, text, ref_style, diffusion_steps=5, alpha=0.3, beta=0.7):
        """Synthesize one utterance.

        Args:
            text: input sentence.
            ref_style: [1, 256] style vector from :meth:`preprocess_audio`.
            diffusion_steps: sampler steps for style prediction.
            alpha: weight of the diffusion-predicted style.
            beta: weight of the reference style.

        Returns:
            1-D numpy waveform at 24 kHz, or None for empty text.
        """
        tokens = self.preprocess_text(text)
        if tokens is None:
            return None

        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(self.device)
        text_mask = length_to_mask(input_lengths).to(self.device)

        with torch.no_grad():
            t_en = self.model.text_encoder(tokens, input_lengths, text_mask)
            bert_dur = self.model.bert(tokens, attention_mask=(~text_mask).int())
            d_en = self.model.bert_encoder(bert_dur).transpose(-1, -2)

            # Sample a style vector conditioned on the text embedding and
            # the reference style.
            s_pred = self.sampler(
                noise=torch.randn((1, 256)).unsqueeze(1).to(self.device),
                embedding=bert_dur,
                features=ref_style,
                num_steps=diffusion_steps
            ).squeeze(1)

            # Blend predicted and reference styles. First 128 dims feed the
            # decoder ("ref"), last 128 dims feed the prosody predictor ("s").
            s = s_pred[:, 128:] * alpha + ref_style[:, 128:] * beta
            ref = s_pred[:, :128] * alpha + ref_style[:, :128] * beta

            # Predict per-token durations.
            d = self.model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
            x, _ = self.model.predictor.lstm(d)
            duration = torch.sigmoid(self.model.predictor.duration_proj(x)).sum(axis=-1)
            pred_dur = torch.round(duration.squeeze()).clamp(min=1)

            # Build the monotonic token→frame alignment matrix. The explicit
            # int(...) avoids relying on torch coercing a 1-element tensor
            # into a dimension size.
            pred_aln_trg = torch.zeros(int(input_lengths.item()), int(pred_dur.sum().data))
            c_frame = 0
            for i in range(pred_aln_trg.size(0)):
                pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
                c_frame += int(pred_dur[i].data)

            # Expand encodings to frame rate, predict F0/energy, then decode.
            en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(self.device))
            F0_pred, N_pred = self.model.predictor.F0Ntrain(en, s)
            asr = (t_en @ pred_aln_trg.unsqueeze(0).to(self.device))

            out = self.model.decoder(asr, F0_pred, N_pred, ref.squeeze().unsqueeze(0))

        # Drop the last 50 samples (presumably a decoder edge artifact —
        # TODO confirm against listening tests).
        return out.squeeze().cpu().numpy()[..., :-50]

    def generate_long_text(self, text, ref_audio_path):
        """Synthesize long text sentence-by-sentence, joined by short pauses."""
        print(f"Processing audio ref: {ref_audio_path}")
        ref_style = self.preprocess_audio(ref_audio_path)

        # Naive sentence split on '.'; good enough for plain prose.
        sentences = text.split('.')
        wavs = []

        start_time = time.time()
        print("Start synthesizing...")

        for sent in sentences:
            sent = sent.strip()
            if not sent:
                continue

            # split('.') always removes the terminator, so restore it
            # unconditionally (the original endswith('.') guard was dead code).
            sent += '.'

            wav = self.inference(sent, ref_style)
            if wav is not None:
                wavs.append(wav)
                # 100 ms of silence between sentences (24 kHz sample rate).
                wavs.append(np.zeros(int(24000 * 0.1)))

        # Guard: text made only of dots/whitespace would otherwise make
        # np.concatenate raise on an empty list.
        if not wavs:
            return np.zeros(0)

        full_wav = np.concatenate(wavs)
        print(f"Done! Total time: {time.time() - start_time:.2f}s")
        return full_wav
| |
|
| | |
if __name__ == "__main__":
    # Build the synthesizer once; model loading dominates startup time.
    tts = StyleTTS2Inference(CONFIG_PATH, MODEL_PATH)

    list_texts = ["xin chào việt nam, hôm nay trời rất đẹp"]

    # Collect one waveform per input text, each followed by half a second
    # of silence (24 kHz sample rate) as a separator.
    segments = []
    for passage in list_texts:
        segments.append(tts.generate_long_text(passage, REF_AUDIO_PATH))
        segments.append(np.zeros(int(24000 * 0.5)))

    final_wav = np.concatenate(segments)
    sf.write(OUTPUT_WAV, final_wav, 24000)
    print(f"File saved to: {OUTPUT_WAV}")