import torch import torchaudio import librosa import yaml import numpy as np import soundfile as sf import phonemizer from munch import Munch import os import time # Import các module từ StyleTTS2 repo from models import * from utils import * from text_utils import TextCleaner from Utils.PLBERT.util import load_plbert from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule # ================= CẤU HÌNH ĐƯỜNG DẪN (SỬA TẠI ĐÂY) ================= CONFIG_PATH = "/workspace/trainTTS/StyleTTS2_custom/Configs/config_ft.yml" # Đường dẫn file config MODEL_PATH = "/workspace/trainTTS/StyleTTS2_custom/Models/mix5voice/model_iter_00032000.pth" # Đường dẫn model (đã clean hoặc chưa clean đều được) REF_AUDIO_PATH = "/workspace/trainTTS/StyleTTS2_custom/test_voice_clone/hue_ban_mai.wav" # File giọng mẫu OUTPUT_WAV = "/workspace/trainTTS/StyleTTS2_custom/test_voice_clone/hue_ban_mai_cut.wav" # File đầu ra DEVICE = "cuda" if torch.cuda.is_available() else "cpu" # ==================================================================== class StyleTTS2Inference: def __init__(self, config_path, model_path, device=DEVICE): self.device = device self.config = yaml.safe_load(open(config_path)) # 1. Khởi tạo các công cụ hỗ trợ self.phonemizer = phonemizer.backend.EspeakBackend( language='vi', preserve_punctuation=True, with_stress=True ) self.text_cleaner = TextCleaner() # 2. Load các thành phần cốt lõi (Structure) # Lưu ý: Vẫn cần load cấu trúc ASR/F0 để build_model không lỗi, # dù sau này không load trọng số vào chúng cũng không sao. text_aligner = load_ASR_models(self.config['ASR_path'], self.config['ASR_config']) pitch_extractor = load_F0_models(self.config['F0_path']) plbert = load_plbert(self.config['PLBERT_dir']) # 3. Xây dựng kiến trúc model (Vỏ rỗng) model_params = recursive_munch(self.config['model_params']) self.model = build_model(model_params, text_aligner, pitch_extractor, plbert) # 4. Load trọng số (State Dict) - PHẦN QUAN TRỌNG NHẤT print(f"Loading model from: {model_path}") params = torch.load(model_path, map_location='cpu') # Nếu file save có key 'net' thì lấy, không thì lấy trực tiếp (tùy cách save) if 'net' in params: params = params['net'] for key in self.model: # --- CHECK QUAN TRỌNG: Chỉ load nếu có trong file checkpoint --- if key not in params: print(f"⚠️ Bỏ qua module '{key}' (không tìm thấy trong checkpoint - OK với model inference)") continue # --------------------------------------------------------------- state_dict = params[key] new_state_dict = {} # Xử lý prefix "module." nếu train bằng DataParallel for k, v in state_dict.items(): if k.startswith("module."): new_state_dict[k[len("module."):]] = v else: new_state_dict[k] = v self.model[key].load_state_dict(new_state_dict, strict=True) self.model[key].eval().to(self.device) print(f"✅ Loaded module: {key}") # 5. Khởi tạo Sampler cho Diffusion self.sampler = DiffusionSampler( self.model.diffusion.diffusion, sampler=ADPM2Sampler(), sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), clamp=False ) print("Model initialization complete.\n") def preprocess_audio(self, audio_path): """Chuyển đổi audio reference thành Style Vector""" wave, sr = librosa.load(audio_path, sr=24000) audio, _ = librosa.effects.trim(wave, top_db=30) to_mel = torchaudio.transforms.MelSpectrogram( n_mels=80, n_fft=2048, win_length=1200, hop_length=300 ) mel = to_mel(torch.from_numpy(audio).float()) mel = (torch.log(1e-5 + mel.unsqueeze(0)) - (-4)) / 4 mel = mel.to(self.device) with torch.no_grad(): ref_s = self.model.style_encoder(mel.unsqueeze(1)) ref_p = self.model.predictor_encoder(mel.unsqueeze(1)) ref_style = torch.cat([ref_s, ref_p], dim=1) return ref_style def preprocess_text(self, text): """Phonemize và Tokenize văn bản""" text = text.strip() if not text: return None ps = self.phonemizer.phonemize([text])[0] tokens = torch.LongTensor(self.text_cleaner(ps)).to(self.device).unsqueeze(0) # Thêm token start/padding tokens = torch.cat([torch.LongTensor([0]).to(self.device).unsqueeze(0), tokens], dim=-1) return tokens def inference(self, text, ref_style, diffusion_steps=5, alpha=0.3, beta=0.7): """Hàm suy luận cốt lõi""" tokens = self.preprocess_text(text) if tokens is None: return None input_lengths = torch.LongTensor([tokens.shape[-1]]).to(self.device) text_mask = length_to_mask(input_lengths).to(self.device) with torch.no_grad(): # Text encoding & BERT t_en = self.model.text_encoder(tokens, input_lengths, text_mask) bert_dur = self.model.bert(tokens, attention_mask=(~text_mask).int()) d_en = self.model.bert_encoder(bert_dur).transpose(-1, -2) # Diffusion Sampling (Tạo style vector đa dạng) s_pred = self.sampler( noise=torch.randn((1, 256)).unsqueeze(1).to(self.device), embedding=bert_dur, features=ref_style, num_steps=diffusion_steps ).squeeze(1) # Trộn style dự đoán và style gốc (Ref audio) # alpha: trọng số giữ lại của style dự đoán (càng cao càng đa dạng nhưng có thể lệch giọng) # beta: trọng số giữ lại của style gốc (càng cao càng giống giọng mẫu) s = s_pred[:, 128:] * alpha + ref_style[:, 128:] * beta ref = s_pred[:, :128] * alpha + ref_style[:, :128] * beta # Predictor (Dự đoán Duration, F0, N) d = self.model.predictor.text_encoder(d_en, s, input_lengths, text_mask) x, _ = self.model.predictor.lstm(d) duration = torch.sigmoid(self.model.predictor.duration_proj(x)).sum(axis=-1) pred_dur = torch.round(duration.squeeze()).clamp(min=1) # Alignment Map Construction pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data)) c_frame = 0 for i in range(pred_aln_trg.size(0)): pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1 c_frame += int(pred_dur[i].data) # Decoder (Sinh âm thanh) en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(self.device)) F0_pred, N_pred = self.model.predictor.F0Ntrain(en, s) asr = (t_en @ pred_aln_trg.unsqueeze(0).to(self.device)) out = self.model.decoder(asr, F0_pred, N_pred, ref.squeeze().unsqueeze(0)) return out.squeeze().cpu().numpy()[..., :-50] # Cắt bớt đuôi silence def generate_long_text(self, text, ref_audio_path): """Xử lý văn bản dài bằng cách tách câu""" print(f"Processing audio ref: {ref_audio_path}") ref_style = self.preprocess_audio(ref_audio_path) # Tách câu đơn giản (có thể cải thiện bằng nltk nếu cần) sentences = text.split('.') wavs = [] start_time = time.time() print("Start synthesizing...") for sent in sentences: if len(sent.strip()) == 0: continue # Thêm dấu chấm để ngắt nghỉ tự nhiên hơn nếu phonemizer cần if not sent.strip().endswith('.'): sent += '.' wav = self.inference(sent, ref_style) if wav is not None: wavs.append(wav) # Thêm khoảng lặng ngắn giữa các câu (0.1s) silence = np.zeros(int(24000 * 0.1)) wavs.append(silence) full_wav = np.concatenate(wavs) print(f"Done! Total time: {time.time() - start_time:.2f}s") return full_wav # ================= MAIN EXECUTION ================= if __name__ == "__main__": # 1. Khởi tạo tts = StyleTTS2Inference(CONFIG_PATH, MODEL_PATH) # 2. Danh sách văn bản cần đọc list_texts = ["xin chào việt nam, hôm nay trời rất đẹp"] full_audio = [] # 3. Chạy vòng lặp tạo giọng for text in list_texts: audio_segment = tts.generate_long_text(text, REF_AUDIO_PATH) full_audio.append(audio_segment) # Thêm khoảng lặng giữa các đoạn văn lớn (0.5s) full_audio.append(np.zeros(int(24000 * 0.5))) # 4. Lưu file kết quả final_wav = np.concatenate(full_audio) sf.write(OUTPUT_WAV, final_wav, 24000) print(f"File saved to: {OUTPUT_WAV}")